Diffstat (limited to 'src/text_recognizer/models/base.py')
-rw-r--r-- | src/text_recognizer/models/base.py | 10 | +++++++---
1 file changed, 7 insertions, 3 deletions
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index caf8065..e89b670 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -356,7 +356,8 @@ class Model(ABC):
             state["optimizer_state"] = self._optimizer.state_dict()
 
         if self._lr_scheduler is not None:
-            state["scheduler_state"] = self._lr_scheduler.state_dict()
+            state["scheduler_state"] = self._lr_scheduler["lr_scheduler"].state_dict()
+            state["scheduler_interval"] = self._lr_scheduler["interval"]
 
         if self._swa_network is not None:
             state["swa_network"] = self._swa_network.state_dict()
@@ -383,8 +384,11 @@ class Model(ABC):
         if self._lr_scheduler is not None:
             # Does not work when loading from a previous checkpoint and trying
             # to train beyond the last max epochs with OneCycleLR.
-            if self._lr_scheduler.__class__.__name__ != "OneCycleLR":
-                self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"])
+            if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR":
+                self._lr_scheduler["lr_scheduler"].load_state_dict(
+                    checkpoint["scheduler_state"]
+                )
+            self._lr_scheduler["interval"] = checkpoint["scheduler_interval"]
 
         if self._swa_network is not None:
             self._swa_network.load_state_dict(checkpoint["swa_network"])
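
For context: the new dict accesses imply that self._lr_scheduler is no longer the bare scheduler object but a dict bundling the scheduler with its stepping interval, under the keys "lr_scheduler" and "interval". The class around the hunks is not shown, so the following is a minimal, self-contained sketch of the save/load round trip after this change; apart from the _lr_scheduler attribute and the scheduler_state / scheduler_interval checkpoint keys visible in the diff, every name here (the Model stand-in, save_checkpoint, load_checkpoint, the toy network) is hypothetical.

# Minimal sketch of the checkpoint round trip after this change.
# Assumes self._lr_scheduler is a dict like
# {"lr_scheduler": <torch scheduler>, "interval": "step" or "epoch"},
# as the dict accesses in the diff suggest. Names other than those
# visible in the diff are hypothetical.
import torch
from torch import nn


class Model:
    def __init__(self) -> None:
        self._network = nn.Linear(4, 2)
        self._optimizer = torch.optim.SGD(self._network.parameters(), lr=0.1)
        self._lr_scheduler = {
            "lr_scheduler": torch.optim.lr_scheduler.StepLR(
                self._optimizer, step_size=10
            ),
            "interval": "epoch",
        }

    def save_checkpoint(self, path: str) -> None:
        state = {"optimizer_state": self._optimizer.state_dict()}
        if self._lr_scheduler is not None:
            # Save the wrapped scheduler's state and, separately, the stepping
            # interval, mirroring the two added lines in the first hunk.
            state["scheduler_state"] = self._lr_scheduler["lr_scheduler"].state_dict()
            state["scheduler_interval"] = self._lr_scheduler["interval"]
        torch.save(state, path)

    def load_checkpoint(self, path: str) -> None:
        checkpoint = torch.load(path)
        self._optimizer.load_state_dict(checkpoint["optimizer_state"])
        if self._lr_scheduler is not None:
            # Skip restoring OneCycleLR state, as in the second hunk; the
            # interval is restored unconditionally.
            if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR":
                self._lr_scheduler["lr_scheduler"].load_state_dict(
                    checkpoint["scheduler_state"]
                )
            self._lr_scheduler["interval"] = checkpoint["scheduler_interval"]


if __name__ == "__main__":
    model = Model()
    model.save_checkpoint("/tmp/ckpt.pt")
    model.load_checkpoint("/tmp/ckpt.pt")

The OneCycleLR carve-out is likely there because its state carries the total step count fixed at construction, so restoring it and then training past the original schedule makes the scheduler raise once it steps beyond that count; starting it fresh avoids the error, at the cost of replaying the cycle.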