From 09bb189ba19304d26ede208fe43c3c882c309d7f Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 2 Oct 2022 02:57:44 +0200 Subject: Fix lr sheduler loading --- text_recognizer/models/base.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) (limited to 'text_recognizer/models') diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index f8f4b40..4587a30 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -55,15 +55,13 @@ class LitBase(LightningModule): def _configure_lr_schedulers( self, optimizer: Type[torch.optim.Optimizer] - ) -> Dict[str, Any]: + ) -> Optional[Dict[str, Any]]: """Configures the lr scheduler.""" log.info( f"Instantiating learning rate scheduler <{self.lr_scheduler_config._target_}>" ) - monitor = self.lr_scheduler_config.monitor - interval = self.lr_scheduler_config.interval - del self.lr_scheduler_config.monitor - del self.lr_scheduler_config.interval + monitor = self.lr_scheduler_config.pop("monitor") + interval = self.lr_scheduler_config.pop("interval") return { "monitor": monitor, @@ -78,8 +76,10 @@ class LitBase(LightningModule): ) -> Dict[str, Any]: """Configures optimizer and lr scheduler.""" optimizer = self._configure_optimizer() - scheduler = self._configure_lr_schedulers(optimizer) - return {"optimizer": optimizer, "lr_scheduler": scheduler} + if self.lr_scheduler_config is not None: + scheduler = self._configure_lr_schedulers(optimizer) + return {"optimizer": optimizer, "lr_scheduler": scheduler} + return {"optimizer": optimizer} def forward(self, data: Tensor) -> Tensor: """Feedforward pass.""" -- cgit v1.2.3-70-g09d2