From 1b3b8073a19f939d18a0bb85247eb0d99284f7cc Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Sun, 20 Sep 2020 11:47:24 +0200 Subject: Bash scripts and some bug fixes. --- src/text_recognizer/models/base.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) (limited to 'src/text_recognizer/models') diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index caf8065..e89b670 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -356,7 +356,8 @@ class Model(ABC): state["optimizer_state"] = self._optimizer.state_dict() if self._lr_scheduler is not None: - state["scheduler_state"] = self._lr_scheduler.state_dict() + state["scheduler_state"] = self._lr_scheduler["lr_scheduler"].state_dict() + state["scheduler_interval"] = self._lr_scheduler["interval"] if self._swa_network is not None: state["swa_network"] = self._swa_network.state_dict() @@ -383,8 +384,11 @@ class Model(ABC): if self._lr_scheduler is not None: # Does not work when loadning from previous checkpoint and trying to train beyond the last max epochs # with OneCycleLR. - if self._lr_scheduler.__class__.__name__ != "OneCycleLR": - self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"]) + if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR": + self._lr_scheduler["lr_scheduler"].load_state_dict( + checkpoint["scheduler_state"] + ) + self._lr_scheduler["interval"] = checkpoint["scheduler_interval"] if self._swa_network is not None: self._swa_network.load_state_dict(checkpoint["swa_network"]) -- cgit v1.2.3-70-g09d2