From 25b5d6983d51e0e791b96a76beb7e49f392cd9a8 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Mon, 7 Dec 2020 22:54:04 +0100 Subject: Segmentation working! --- src/text_recognizer/models/base.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) (limited to 'src/text_recognizer/models/base.py') diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index d394b4c..f2cd4b8 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -159,7 +159,7 @@ class Model(ABC): self.test_dataset = self.dataset(train=False, **self.dataset_args["args"]) self.test_dataset.load_or_generate_data() - # Set the flag to true to disable ability to load data agian. + # Set the flag to true to disable ability to load data again. self.data_prepared = True def train_dataloader(self) -> DataLoader: @@ -260,7 +260,7 @@ class Model(ABC): @property def mapping(self) -> Dict: """Returns the mapping between network output and Emnist character.""" - return self._mapper.mapping + return self._mapper.mapping if self._mapper is not None else None def eval(self) -> None: """Sets the network to evaluation mode.""" @@ -341,7 +341,7 @@ class Model(ABC): if input_shape is not None: summary(self.network, input_shape, depth=depth, device=device) elif self._input_shape is not None: - input_shape = (1,) + tuple(self._input_shape) + input_shape = tuple(self._input_shape) summary(self.network, input_shape, depth=depth, device=device) else: logger.warning("Could not print summary as input shape is not set.") @@ -427,7 +427,7 @@ class Model(ABC): ) shutil.copyfile(filepath, str(checkpoint_path / "best.pt")) - def load_weights(self, network_fn: Type[nn.Module]) -> None: + def load_weights(self, network_fn: Optional[Type[nn.Module]] = None) -> None: """Load the network weights.""" logger.debug("Loading network with pretrained weights.") filename = glob(self.weights_filename)[0] @@ -441,7 +441,8 @@ class Model(ABC): weights = state_dict["model_state"] # Initializes the network with trained weights. - self._network = network_fn(**self._network_args) + if network_fn is not None: + self._network = network_fn(**self._network_args) self._network.load_state_dict(weights) if "swa_network" in state_dict: -- cgit v1.2.3-70-g09d2