| author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-20 00:14:27 +0200 |
|---|---|---|
| committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-20 00:14:27 +0200 |
| commit | e181195a699d7fa237f256d90ab4dedffc03d405 (patch) | |
| tree | 6d8d50731a7267c56f7bf3ed5ecec3990c0e55a5 /src/text_recognizer/models | |
| parent | 3b06ef615a8db67a03927576e0c12fbfb2501f5f (diff) | |
Minor bug fixes etc.
Diffstat (limited to 'src/text_recognizer/models')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | src/text_recognizer/models/base.py | 56 |
| -rw-r--r-- | src/text_recognizer/models/character_model.py | 6 |
| -rw-r--r-- | src/text_recognizer/models/line_ctc_model.py | 8 |
3 files changed, 39 insertions, 31 deletions
```diff
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index d23fe56..caf8065 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -77,9 +77,9 @@ class Model(ABC):
 
         # Stochastic Weight Averaging placeholders.
         self.swa_args = swa_args
-        self._swa_start = None
         self._swa_scheduler = None
         self._swa_network = None
+        self._use_swa_model = False
 
         # Experiment directory.
         self.model_dir = None
@@ -220,15 +220,24 @@ class Model(ABC):
         if self._optimizer and self._lr_scheduler is not None:
             if "OneCycleLR" in str(self._lr_scheduler):
                 self.lr_scheduler_args["steps_per_epoch"] = len(self.train_dataloader())
-            self._lr_scheduler = self._lr_scheduler(
-                self._optimizer, **self.lr_scheduler_args
-            )
-        else:
-            self._lr_scheduler = None
+
+            # Assume lr scheduler should update at each epoch if not specified.
+            if "interval" not in self.lr_scheduler_args:
+                interval = "epoch"
+            else:
+                interval = self.lr_scheduler_args.pop("interval")
+            self._lr_scheduler = {
+                "lr_scheduler": self._lr_scheduler(
+                    self._optimizer, **self.lr_scheduler_args
+                ),
+                "interval": interval,
+            }
 
         if self.swa_args is not None:
-            self._swa_start = self.swa_args["start"]
-            self._swa_scheduler = SWALR(self._optimizer, swa_lr=self.swa_args["lr"])
+            self._swa_scheduler = {
+                "swa_scheduler": SWALR(self._optimizer, swa_lr=self.swa_args["lr"]),
+                "swa_start": self.swa_args["start"],
+            }
             self._swa_network = AveragedModel(self._network).to(self.device)
 
     @property
@@ -280,21 +289,16 @@ class Model(ABC):
         return self._optimizer
 
     @property
-    def lr_scheduler(self) -> Optional[Callable]:
-        """Learning rate scheduler."""
+    def lr_scheduler(self) -> Optional[Dict]:
+        """Returns a directory with the learning rate scheduler."""
         return self._lr_scheduler
 
     @property
-    def swa_scheduler(self) -> Optional[Callable]:
-        """Returns the stochastic weight averaging scheduler."""
+    def swa_scheduler(self) -> Optional[Dict]:
+        """Returns a directory with the stochastic weight averaging scheduler."""
         return self._swa_scheduler
 
     @property
-    def swa_start(self) -> Optional[Callable]:
-        """Returns the start epoch of stochastic weight averaging."""
-        return self._swa_start
-
-    @property
     def swa_network(self) -> Optional[Callable]:
         """Returns the stochastic weight averaging network."""
         return self._swa_network
@@ -311,20 +315,32 @@ class Model(ABC):
         WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True)
         return str(WEIGHT_DIRNAME / f"{self._name}_weights.pt")
 
+    def use_swa_model(self) -> None:
+        """Set to use predictions from SWA model."""
+        if self.swa_network is not None:
+            self._use_swa_model = True
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Feedforward pass with the network."""
+        if self._use_swa_model:
+            return self.swa_network(x)
+        else:
+            return self.network(x)
+
     def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor:
         """Compute the loss."""
         return self.criterion(output, targets)
 
     def summary(
-        self, input_shape: Optional[Tuple[int, int, int]] = None, depth: int = 5
+        self, input_shape: Optional[Tuple[int, int, int]] = None, depth: int = 3
     ) -> None:
         """Prints a summary of the network architecture."""
         if input_shape is not None:
-            summary(self._network, input_shape, depth=depth, device=self.device)
+            summary(self.network, input_shape, depth=depth, device=self.device)
         elif self._input_shape is not None:
             input_shape = (1,) + tuple(self._input_shape)
-            summary(self._network, input_shape, depth=depth, device=self.device)
+            summary(self.network, input_shape, depth=depth, device=self.device)
         else:
             logger.warning("Could not print summary as input shape is not set.")
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index 64ba693..50e94a2 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -75,11 +75,7 @@ class CharacterModel(Model):
 
         # Put the image tensor on the device the model weights are on.
         image = image.to(self.device)
-        logits = (
-            self.swa_network(image)
-            if self.swa_network is not None
-            else self.network(image)
-        )
+        logits = self.forward(image)
 
         prediction = self.softmax(logits.squeeze(0))
 
diff --git a/src/text_recognizer/models/line_ctc_model.py b/src/text_recognizer/models/line_ctc_model.py
index af41f18..16eaed3 100644
--- a/src/text_recognizer/models/line_ctc_model.py
+++ b/src/text_recognizer/models/line_ctc_model.py
@@ -98,16 +98,12 @@ class LineCTCModel(Model):
 
         # Put the image tensor on the device the model weights are on.
         image = image.to(self.device)
-        log_probs = (
-            self.swa_network(image)
-            if self.swa_network is not None
-            else self.network(image)
-        )
+        log_probs = self.forward(image)
 
         raw_pred, _ = greedy_decoder(
             predictions=log_probs,
             character_mapper=self.mapper,
-            blank_label=80,
+            blank_label=79,
             collapse_repeated=True,
         )
 
```
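The commit replaces the bare scheduler objects with small dictionaries so that the caller can decide when each scheduler should step. The training loop itself is not part of this diff; the sketch below is an assumption that only relies on the dictionary keys visible in the hunk at `@@ -220,15 +220,24 @@` above (`"lr_scheduler"`, `"interval"`, `"swa_scheduler"`, `"swa_start"`), while `run_epoch` and the loop structure are hypothetical.

```python
# Minimal sketch (not from this commit): one way a training loop could consume
# the scheduler dictionaries built above. Only the dict keys are taken from the
# diff; run_epoch and the surrounding loop structure are assumed.
def run_epoch(model, epoch: int) -> None:
    for batch in model.train_dataloader():
        ...  # forward pass, loss, backward, optimizer step

        # Schedulers such as OneCycleLR update once per optimizer step.
        if model.lr_scheduler is not None and model.lr_scheduler["interval"] == "step":
            model.lr_scheduler["lr_scheduler"].step()

    if model.swa_scheduler is not None and epoch >= model.swa_scheduler["swa_start"]:
        # After the SWA start epoch, average the weights and switch to the SWA LR.
        model.swa_network.update_parameters(model.network)
        model.swa_scheduler["swa_scheduler"].step()
    elif model.lr_scheduler is not None and model.lr_scheduler["interval"] == "epoch":
        model.lr_scheduler["lr_scheduler"].step()
```

Keeping `swa_start` inside the same dictionary as the SWA scheduler is what lets the commit drop the separate `swa_start` property.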
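The new `use_swa_model()` flag is also what lets the prediction code in `CharacterModel` and `LineCTCModel` collapse their SWA branching into a single `self.forward(image)` call. A minimal usage sketch, assuming an already constructed and trained `Model` subclass instance; only `use_swa_model()`, `forward()`, and `device` are taken from the diff, and `predict_with_swa` is a hypothetical helper:

```python
# Minimal sketch (not from this commit): opting into the SWA-averaged weights
# at inference time. `model` stands for any constructed Model subclass.
import torch


def predict_with_swa(model, image: torch.Tensor) -> torch.Tensor:
    """Forward pass that prefers the SWA network when one was configured."""
    model.use_swa_model()  # no-op unless an SWA network exists
    with torch.no_grad():
        return model.forward(image.to(model.device))
```

Routing every prediction through `forward()` keeps the SWA-versus-plain-network decision in one place instead of duplicating the conditional in each subclass.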