diff options
| author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-12-07 22:54:04 +0100 | 
|---|---|---|
| committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-12-07 22:54:04 +0100 | 
| commit | 25b5d6983d51e0e791b96a76beb7e49f392cd9a8 (patch) | |
| tree | 526ba739714b3d040f7810c1a6be3ff0ba37fdb1 /src/text_recognizer/models/base.py | |
| parent | 5529e0fc9ca39e81fe0f08a54f257d32f0afe120 (diff) | |
Segmentation working!
Diffstat (limited to 'src/text_recognizer/models/base.py')
| -rw-r--r-- | src/text_recognizer/models/base.py | 11 | 
1 files changed, 6 insertions, 5 deletions
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index d394b4c..f2cd4b8 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -159,7 +159,7 @@ class Model(ABC):              self.test_dataset = self.dataset(train=False, **self.dataset_args["args"])              self.test_dataset.load_or_generate_data() -            # Set the flag to true to disable ability to load data agian. +            # Set the flag to true to disable ability to load data again.              self.data_prepared = True      def train_dataloader(self) -> DataLoader: @@ -260,7 +260,7 @@ class Model(ABC):      @property      def mapping(self) -> Dict:          """Returns the mapping between network output and Emnist character.""" -        return self._mapper.mapping +        return self._mapper.mapping if self._mapper is not None else None      def eval(self) -> None:          """Sets the network to evaluation mode.""" @@ -341,7 +341,7 @@ class Model(ABC):          if input_shape is not None:              summary(self.network, input_shape, depth=depth, device=device)          elif self._input_shape is not None: -            input_shape = (1,) + tuple(self._input_shape) +            input_shape = tuple(self._input_shape)              summary(self.network, input_shape, depth=depth, device=device)          else:              logger.warning("Could not print summary as input shape is not set.") @@ -427,7 +427,7 @@ class Model(ABC):              )              shutil.copyfile(filepath, str(checkpoint_path / "best.pt")) -    def load_weights(self, network_fn: Type[nn.Module]) -> None: +    def load_weights(self, network_fn: Optional[Type[nn.Module]] = None) -> None:          """Load the network weights."""          logger.debug("Loading network with pretrained weights.")          filename = glob(self.weights_filename)[0] @@ -441,7 +441,8 @@ class Model(ABC):          weights = state_dict["model_state"]          # Initializes the network with trained weights. -        self._network = network_fn(**self._network_args) +        if network_fn is not None: +            self._network = network_fn(**self._network_args)          self._network.load_state_dict(weights)          if "swa_network" in state_dict:  |