summaryrefslogtreecommitdiff
path: root/src/text_recognizer/models/base.py
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-11-12 23:42:03 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2020-11-12 23:42:03 +0100
commit8fdb6435e15703fa5b76df19728d905650ee1aef (patch)
treebe3bec9e5cab4ef7f9d94528d102e57ce9b16c3f /src/text_recognizer/models/base.py
parentdc28cbe2b4ed77be92ee8b2b69a20689c3bf02a4 (diff)
parent6cb08a110620ee09fe9d8a5d008197a801d025df (diff)
Working cnn transformer.
Diffstat (limited to 'src/text_recognizer/models/base.py')
-rw-r--r--src/text_recognizer/models/base.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index cc44c92..a945b41 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -49,7 +49,7 @@ class Model(ABC):
network_args (Optional[Dict]): Arguments for the network. Defaults to None.
dataset_args (Optional[Dict]): Arguments for the dataset.
metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
- criterion (Optional[Callable]): The criterion to evaulate the preformance of the network.
+ criterion (Optional[Callable]): The criterion to evaluate the performance of the network.
Defaults to None.
criterion_args (Optional[Dict]): Dict of arguments for criterion. Defaults to None.
optimizer (Optional[Callable]): The optimizer for updating the weights. Defaults to None.
@@ -221,7 +221,7 @@ class Model(ABC):
def _configure_network(self, network_fn: Type[nn.Module]) -> None:
"""Loads the network."""
- # If no network arguemnts are given, load pretrained weights if they exist.
+ # If no network arguments are given, load pretrained weights if they exist.
if self._network_args is None:
self.load_weights(network_fn)
else:
@@ -245,7 +245,7 @@ class Model(ABC):
self._optimizer = None
if self._optimizer and self._lr_scheduler is not None:
- if "OneCycleLR" in str(self._lr_scheduler):
+ if "steps_per_epoch" in self.lr_scheduler_args:
self.lr_scheduler_args["steps_per_epoch"] = len(self.train_dataloader())
# Assume lr scheduler should update at each epoch if not specified.
@@ -412,7 +412,7 @@ class Model(ABC):
self._optimizer.load_state_dict(checkpoint["optimizer_state"])
if self._lr_scheduler is not None:
- # Does not work when loadning from previous checkpoint and trying to train beyond the last max epochs
+ # Does not work when loading from previous checkpoint and trying to train beyond the last max epochs
# with OneCycleLR.
if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR":
self._lr_scheduler["lr_scheduler"].load_state_dict(