diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-10-02 02:57:44 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-10-02 02:57:44 +0200 |
commit | 09bb189ba19304d26ede208fe43c3c882c309d7f (patch) | |
tree | ced135e1b644cd7f93837ade31e9427dc8f11533 | |
parent | c480d7bc1459a591f5647f62277f5be4e02b1ce6 (diff) |
Fix lr sheduler loading
-rw-r--r-- | text_recognizer/models/base.py | 14 |
1 files changed, 7 insertions, 7 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index f8f4b40..4587a30 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -55,15 +55,13 @@ class LitBase(LightningModule): def _configure_lr_schedulers( self, optimizer: Type[torch.optim.Optimizer] - ) -> Dict[str, Any]: + ) -> Optional[Dict[str, Any]]: """Configures the lr scheduler.""" log.info( f"Instantiating learning rate scheduler <{self.lr_scheduler_config._target_}>" ) - monitor = self.lr_scheduler_config.monitor - interval = self.lr_scheduler_config.interval - del self.lr_scheduler_config.monitor - del self.lr_scheduler_config.interval + monitor = self.lr_scheduler_config.pop("monitor") + interval = self.lr_scheduler_config.pop("interval") return { "monitor": monitor, @@ -78,8 +76,10 @@ class LitBase(LightningModule): ) -> Dict[str, Any]: """Configures optimizer and lr scheduler.""" optimizer = self._configure_optimizer() - scheduler = self._configure_lr_schedulers(optimizer) - return {"optimizer": optimizer, "lr_scheduler": scheduler} + if self.lr_scheduler_config is not None: + scheduler = self._configure_lr_schedulers(optimizer) + return {"optimizer": optimizer, "lr_scheduler": scheduler} + return {"optimizer": optimizer} def forward(self, data: Tensor) -> Tensor: """Feedforward pass.""" |