Diffstat (limited to 'src/text_recognizer/models/base.py')
-rw-r--r--  src/text_recognizer/models/base.py | 8
1 file changed, 4 insertions, 4 deletions
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index cc44c92..a945b41 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -49,7 +49,7 @@ class Model(ABC):
         network_args (Optional[Dict]): Arguments for the network. Defaults to None.
         dataset_args (Optional[Dict]): Arguments for the dataset.
         metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
-        criterion (Optional[Callable]): The criterion to evaulate the preformance of the network.
+        criterion (Optional[Callable]): The criterion to evaluate the performance of the network.
             Defaults to None.
         criterion_args (Optional[Dict]): Dict of arguments for criterion. Defaults to None.
         optimizer (Optional[Callable]): The optimizer for updating the weights. Defaults to None.
@@ -221,7 +221,7 @@ class Model(ABC):
 
     def _configure_network(self, network_fn: Type[nn.Module]) -> None:
         """Loads the network."""
-        # If no network arguemnts are given, load pretrained weights if they exist.
+        # If no network arguments are given, load pretrained weights if they exist.
         if self._network_args is None:
             self.load_weights(network_fn)
         else:
@@ -245,7 +245,7 @@ class Model(ABC):
             self._optimizer = None
 
         if self._optimizer and self._lr_scheduler is not None:
-            if "OneCycleLR" in str(self._lr_scheduler):
+            if "steps_per_epoch" in self.lr_scheduler_args:
                 self.lr_scheduler_args["steps_per_epoch"] = len(self.train_dataloader())
 
             # Assume lr scheduler should update at each epoch if not specified.
@@ -412,7 +412,7 @@ class Model(ABC):
             self._optimizer.load_state_dict(checkpoint["optimizer_state"])
 
         if self._lr_scheduler is not None:
-            # Does not work when loadning from previous checkpoint and trying to train beyond the last max epochs
+            # Does not work when loading from previous checkpoint and trying to train beyond the last max epochs
             # with OneCycleLR.
             if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR":
                 self._lr_scheduler["lr_scheduler"].load_state_dict(
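The functional change is in the third hunk: instead of string-matching the scheduler object to detect OneCycleLR, the code now checks whether the scheduler argument dict contains a "steps_per_epoch" key and, if so, overwrites it with the dataloader length. Below is a minimal, self-contained sketch of that pattern with PyTorch's OneCycleLR; the model, optimizer, and dataloader names are illustrative and not taken from this repository.

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    # Toy model, optimizer, and dataloader standing in for the ones the Model class builds.
    model = nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    train_loader = DataLoader(TensorDataset(torch.randn(64, 10)), batch_size=8)

    # Scheduler args as they might arrive from a config; steps_per_epoch is a
    # placeholder until the dataloader length is known.
    lr_scheduler_args = {"max_lr": 0.1, "epochs": 5, "steps_per_epoch": None}

    # Mirrors the patched check: fill in steps_per_epoch only when the key is
    # present in the args, rather than string-matching on the scheduler class.
    if "steps_per_epoch" in lr_scheduler_args:
        lr_scheduler_args["steps_per_epoch"] = len(train_loader)

    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, **lr_scheduler_args)

    # OneCycleLR is stepped once per batch, for epochs * steps_per_epoch steps in total.
    for epoch in range(lr_scheduler_args["epochs"]):
        for (batch,) in train_loader:
            loss = model(batch).sum()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

Keying on "steps_per_epoch" rather than the class name also covers any other scheduler whose arguments include that key, which appears to be the point of the change; the last hunk is unrelated and only fixes a typo in the comment about skipping OneCycleLR state when resuming from a checkpoint.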