summaryrefslogtreecommitdiff
path: root/src/text_recognizer/models/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/models/base.py')
-rw-r--r--src/text_recognizer/models/base.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index cc44c92..a945b41 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -49,7 +49,7 @@ class Model(ABC):
network_args (Optional[Dict]): Arguments for the network. Defaults to None.
dataset_args (Optional[Dict]): Arguments for the dataset.
metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
- criterion (Optional[Callable]): The criterion to evaulate the preformance of the network.
+ criterion (Optional[Callable]): The criterion to evaluate the performance of the network.
Defaults to None.
criterion_args (Optional[Dict]): Dict of arguments for criterion. Defaults to None.
optimizer (Optional[Callable]): The optimizer for updating the weights. Defaults to None.
@@ -221,7 +221,7 @@ class Model(ABC):
def _configure_network(self, network_fn: Type[nn.Module]) -> None:
"""Loads the network."""
- # If no network arguemnts are given, load pretrained weights if they exist.
+ # If no network arguments are given, load pretrained weights if they exist.
if self._network_args is None:
self.load_weights(network_fn)
else:
@@ -245,7 +245,7 @@ class Model(ABC):
self._optimizer = None
if self._optimizer and self._lr_scheduler is not None:
- if "OneCycleLR" in str(self._lr_scheduler):
+ if "steps_per_epoch" in self.lr_scheduler_args:
self.lr_scheduler_args["steps_per_epoch"] = len(self.train_dataloader())
# Assume lr scheduler should update at each epoch if not specified.
@@ -412,7 +412,7 @@ class Model(ABC):
self._optimizer.load_state_dict(checkpoint["optimizer_state"])
if self._lr_scheduler is not None:
- # Does not work when loadning from previous checkpoint and trying to train beyond the last max epochs
+ # Does not work when loading from previous checkpoint and trying to train beyond the last max epochs
# with OneCycleLR.
if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR":
self._lr_scheduler["lr_scheduler"].load_state_dict(