diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-11-12 23:42:03 +0100 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-11-12 23:42:03 +0100 |
commit | 8fdb6435e15703fa5b76df19728d905650ee1aef (patch) | |
tree | be3bec9e5cab4ef7f9d94528d102e57ce9b16c3f /src/text_recognizer/models/base.py | |
parent | dc28cbe2b4ed77be92ee8b2b69a20689c3bf02a4 (diff) | |
parent | 6cb08a110620ee09fe9d8a5d008197a801d025df (diff) |
Working cnn transformer.
Diffstat (limited to 'src/text_recognizer/models/base.py')
-rw-r--r-- | src/text_recognizer/models/base.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index cc44c92..a945b41 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -49,7 +49,7 @@ class Model(ABC): network_args (Optional[Dict]): Arguments for the network. Defaults to None. dataset_args (Optional[Dict]): Arguments for the dataset. metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None. - criterion (Optional[Callable]): The criterion to evaulate the preformance of the network. + criterion (Optional[Callable]): The criterion to evaluate the performance of the network. Defaults to None. criterion_args (Optional[Dict]): Dict of arguments for criterion. Defaults to None. optimizer (Optional[Callable]): The optimizer for updating the weights. Defaults to None. @@ -221,7 +221,7 @@ class Model(ABC): def _configure_network(self, network_fn: Type[nn.Module]) -> None: """Loads the network.""" - # If no network arguemnts are given, load pretrained weights if they exist. + # If no network arguments are given, load pretrained weights if they exist. if self._network_args is None: self.load_weights(network_fn) else: @@ -245,7 +245,7 @@ class Model(ABC): self._optimizer = None if self._optimizer and self._lr_scheduler is not None: - if "OneCycleLR" in str(self._lr_scheduler): + if "steps_per_epoch" in self.lr_scheduler_args: self.lr_scheduler_args["steps_per_epoch"] = len(self.train_dataloader()) # Assume lr scheduler should update at each epoch if not specified. @@ -412,7 +412,7 @@ class Model(ABC): self._optimizer.load_state_dict(checkpoint["optimizer_state"]) if self._lr_scheduler is not None: - # Does not work when loadning from previous checkpoint and trying to train beyond the last max epochs + # Does not work when loading from previous checkpoint and trying to train beyond the last max epochs # with OneCycleLR. if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR": self._lr_scheduler["lr_scheduler"].load_state_dict( |