diff options
-rw-r--r-- | text_recognizer/models/base.py | 70 |
1 files changed, 26 insertions, 44 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 26cc18c..1ebb256 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -1,5 +1,5 @@ """Base PyTorch Lightning model.""" -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Dict, Optional, Tuple, Type import hydra from loguru import logger as log @@ -47,57 +47,39 @@ class LitBase(LightningModule): """Optimal way to set grads to zero.""" optimizer.zero_grad(set_to_none=True) - def _configure_optimizer(self) -> List[Type[torch.optim.Optimizer]]: + def _configure_optimizer(self) -> Type[torch.optim.Optimizer]: """Configures the optimizer.""" - optimizers = [] - for optimizer_config in self.optimizer_configs.values(): - module = self - for m in str(optimizer_config.parameters).split("."): - module = getattr(module, m) - del optimizer_config.parameters - log.info(f"Instantiating optimizer <{optimizer_config._target_}>") - optimizers.append( - hydra.utils.instantiate(optimizer_config, params=module.parameters()) - ) - return optimizers + return hydra.utils.instantiate( + self.optimizer_config, params=self.network.parameters() + ) def _configure_lr_schedulers( - self, optimizers: List[Type[torch.optim.Optimizer]] - ) -> List[Dict[str, Any]]: + self, optimizer: Type[torch.optim.Optimizer] + ) -> Dict[str, Any]: """Configures the lr scheduler.""" - if self.lr_scheduler_configs is None: - return [] - schedulers = [] - for optimizer, lr_scheduler_config in zip( - optimizers, self.lr_scheduler_configs.values() - ): - # Extract non-class arguments. - monitor = lr_scheduler_config.monitor - interval = lr_scheduler_config.interval - del lr_scheduler_config.monitor - del lr_scheduler_config.interval - - log.info( - f"Instantiating learning rate scheduler <{lr_scheduler_config._target_}>" - ) - scheduler = { - "monitor": monitor, - "interval": interval, - "scheduler": hydra.utils.instantiate( - lr_scheduler_config, optimizer=optimizer - ), - } - schedulers.append(scheduler) - - return schedulers + log.info( + f"Instantiating learning rate scheduler <{self.lr_scheduler_config._target_}>" + ) + monitor = self.lr_scheduler_config.monitor + interval = self.lr_scheduler_config.interval + del self.lr_scheduler_config.monitor + del self.lr_scheduler_config.interval + + return { + "monitor": monitor, + "interval": interval, + "scheduler": hydra.utils.instantiate( + self.lr_scheduler_config, optimizer=optimizer + ), + } def configure_optimizers( self, - ) -> Tuple[List[Type[torch.optim.Optimizer]], List[Dict[str, Any]]]: + ) -> Dict[str, Any]: """Configures optimizer and lr scheduler.""" - optimizers = self._configure_optimizer() - schedulers = self._configure_lr_schedulers(optimizers) - return optimizers, schedulers + optimizer = self._configure_optimizer() + scheduler = self._configure_lr_schedulers(optimizer) + return {"optimizer": optimizer, "scheduler": scheduler} def forward(self, data: Tensor) -> Tensor: """Feedforward pass.""" |