summaryrefslogtreecommitdiff
path: root/src/text_recognizer/models/base.py
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-09-20 11:47:24 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-09-20 11:47:24 +0200
commit1b3b8073a19f939d18a0bb85247eb0d99284f7cc (patch)
treee74e78230ebb179237c063fecf0b52458ce3aa3e /src/text_recognizer/models/base.py
parent6137f43c910946301279825e50759a9dd76c6131 (diff)
Bash scripts and some bug fixes.
Diffstat (limited to 'src/text_recognizer/models/base.py')
-rw-r--r--src/text_recognizer/models/base.py10
1 files changed, 7 insertions, 3 deletions
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index caf8065..e89b670 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -356,7 +356,8 @@ class Model(ABC):
state["optimizer_state"] = self._optimizer.state_dict()
if self._lr_scheduler is not None:
- state["scheduler_state"] = self._lr_scheduler.state_dict()
+ state["scheduler_state"] = self._lr_scheduler["lr_scheduler"].state_dict()
+ state["scheduler_interval"] = self._lr_scheduler["interval"]
if self._swa_network is not None:
state["swa_network"] = self._swa_network.state_dict()
@@ -383,8 +384,11 @@ class Model(ABC):
if self._lr_scheduler is not None:
# Does not work when loadning from previous checkpoint and trying to train beyond the last max epochs
# with OneCycleLR.
- if self._lr_scheduler.__class__.__name__ != "OneCycleLR":
- self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"])
+ if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR":
+ self._lr_scheduler["lr_scheduler"].load_state_dict(
+ checkpoint["scheduler_state"]
+ )
+ self._lr_scheduler["interval"] = checkpoint["scheduler_interval"]
if self._swa_network is not None:
self._swa_network.load_state_dict(checkpoint["swa_network"])