diff options
Diffstat (limited to 'text_recognizer/models')
-rw-r--r-- | text_recognizer/models/base.py | 14 |
1 files changed, 7 insertions, 7 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index f8f4b40..4587a30 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -55,15 +55,13 @@ class LitBase(LightningModule): def _configure_lr_schedulers( self, optimizer: Type[torch.optim.Optimizer] - ) -> Dict[str, Any]: + ) -> Optional[Dict[str, Any]]: """Configures the lr scheduler.""" log.info( f"Instantiating learning rate scheduler <{self.lr_scheduler_config._target_}>" ) - monitor = self.lr_scheduler_config.monitor - interval = self.lr_scheduler_config.interval - del self.lr_scheduler_config.monitor - del self.lr_scheduler_config.interval + monitor = self.lr_scheduler_config.pop("monitor") + interval = self.lr_scheduler_config.pop("interval") return { "monitor": monitor, @@ -78,8 +76,10 @@ class LitBase(LightningModule): ) -> Dict[str, Any]: """Configures optimizer and lr scheduler.""" optimizer = self._configure_optimizer() - scheduler = self._configure_lr_schedulers(optimizer) - return {"optimizer": optimizer, "lr_scheduler": scheduler} + if self.lr_scheduler_config is not None: + scheduler = self._configure_lr_schedulers(optimizer) + return {"optimizer": optimizer, "lr_scheduler": scheduler} + return {"optimizer": optimizer} def forward(self, data: Tensor) -> Tensor: """Feedforward pass.""" |