summaryrefslogtreecommitdiff
path: root/text_recognizer/models
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-10-02 02:57:44 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-10-02 02:57:44 +0200
commit09bb189ba19304d26ede208fe43c3c882c309d7f (patch)
treeced135e1b644cd7f93837ade31e9427dc8f11533 /text_recognizer/models
parentc480d7bc1459a591f5647f62277f5be4e02b1ce6 (diff)
Fix lr sheduler loading
Diffstat (limited to 'text_recognizer/models')
-rw-r--r--text_recognizer/models/base.py14
1 files changed, 7 insertions, 7 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index f8f4b40..4587a30 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -55,15 +55,13 @@ class LitBase(LightningModule):
def _configure_lr_schedulers(
self, optimizer: Type[torch.optim.Optimizer]
- ) -> Dict[str, Any]:
+ ) -> Optional[Dict[str, Any]]:
"""Configures the lr scheduler."""
log.info(
f"Instantiating learning rate scheduler <{self.lr_scheduler_config._target_}>"
)
- monitor = self.lr_scheduler_config.monitor
- interval = self.lr_scheduler_config.interval
- del self.lr_scheduler_config.monitor
- del self.lr_scheduler_config.interval
+ monitor = self.lr_scheduler_config.pop("monitor")
+ interval = self.lr_scheduler_config.pop("interval")
return {
"monitor": monitor,
@@ -78,8 +76,10 @@ class LitBase(LightningModule):
) -> Dict[str, Any]:
"""Configures optimizer and lr scheduler."""
optimizer = self._configure_optimizer()
- scheduler = self._configure_lr_schedulers(optimizer)
- return {"optimizer": optimizer, "lr_scheduler": scheduler}
+ if self.lr_scheduler_config is not None:
+ scheduler = self._configure_lr_schedulers(optimizer)
+ return {"optimizer": optimizer, "lr_scheduler": scheduler}
+ return {"optimizer": optimizer}
def forward(self, data: Tensor) -> Tensor:
"""Feedforward pass."""