summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--text_recognizer/models/base.py14
1 files changed, 7 insertions, 7 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index f8f4b40..4587a30 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -55,15 +55,13 @@ class LitBase(LightningModule):
def _configure_lr_schedulers(
self, optimizer: Type[torch.optim.Optimizer]
- ) -> Dict[str, Any]:
+ ) -> Optional[Dict[str, Any]]:
"""Configures the lr scheduler."""
log.info(
f"Instantiating learning rate scheduler <{self.lr_scheduler_config._target_}>"
)
- monitor = self.lr_scheduler_config.monitor
- interval = self.lr_scheduler_config.interval
- del self.lr_scheduler_config.monitor
- del self.lr_scheduler_config.interval
+ monitor = self.lr_scheduler_config.pop("monitor")
+ interval = self.lr_scheduler_config.pop("interval")
return {
"monitor": monitor,
@@ -78,8 +76,10 @@ class LitBase(LightningModule):
) -> Dict[str, Any]:
"""Configures optimizer and lr scheduler."""
optimizer = self._configure_optimizer()
- scheduler = self._configure_lr_schedulers(optimizer)
- return {"optimizer": optimizer, "lr_scheduler": scheduler}
+ if self.lr_scheduler_config is not None:
+ scheduler = self._configure_lr_schedulers(optimizer)
+ return {"optimizer": optimizer, "lr_scheduler": scheduler}
+ return {"optimizer": optimizer}
def forward(self, data: Tensor) -> Tensor:
"""Feedforward pass."""