diff options
Diffstat (limited to 'text_recognizer/models/base.py')
-rw-r--r-- | text_recognizer/models/base.py | 81 |
1 files changed, 48 insertions, 33 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index ab3fa35..8b68ed9 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -24,8 +24,8 @@ class BaseLitModel(LightningModule): network: Type[nn.Module] = attr.ib() mapping: Type[AbstractMapping] = attr.ib() loss_fn: Type[nn.Module] = attr.ib() - optimizer_config: DictConfig = attr.ib() - lr_scheduler_config: DictConfig = attr.ib() + optimizer_configs: DictConfig = attr.ib() + lr_scheduler_configs: DictConfig = attr.ib() train_acc: torchmetrics.Accuracy = attr.ib( init=False, default=torchmetrics.Accuracy() ) @@ -45,40 +45,55 @@ class BaseLitModel(LightningModule): ) -> None: optimizer.zero_grad(set_to_none=True) - def _configure_optimizer(self) -> Type[torch.optim.Optimizer]: + def _configure_optimizer(self) -> List[Type[torch.optim.Optimizer]]: """Configures the optimizer.""" - log.info(f"Instantiating optimizer <{self.optimizer_config._target_}>") - return hydra.utils.instantiate( - self.optimizer_config, params=self.network.parameters() - ) - - def _configure_lr_scheduler( - self, optimizer: Type[torch.optim.Optimizer] - ) -> Dict[str, Any]: + optimizers = [] + for optimizer_config in self.optimizer_configs.values(): + network = getattr(self, optimizer_config.parameters) + del optimizer_config.parameters + log.info(f"Instantiating optimizer <{optimizer_config._target_}>") + optimizers.append( + hydra.utils.instantiate( + self.optimizer_config, params=network.parameters() + ) + ) + return optimizers + + def _configure_lr_schedulers( + self, optimizers: List[Type[torch.optim.Optimizer]] + ) -> List[Dict[str, Any]]: """Configures the lr scheduler.""" - # Extract non-class arguments. - monitor = self.lr_scheduler_config.monitor - interval = self.lr_scheduler_config.interval - del self.lr_scheduler_config.monitor - del self.lr_scheduler_config.interval - - log.info( - f"Instantiating learning rate scheduler <{self.lr_scheduler_config._target_}>" - ) - scheduler = { - "monitor": monitor, - "interval": interval, - "scheduler": hydra.utils.instantiate( - self.lr_scheduler_config, optimizer=optimizer - ), - } - return scheduler - - def configure_optimizers(self) -> Tuple[List[type], List[Dict[str, Any]]]: + schedulers = [] + for optimizer, lr_scheduler_config in zip( + optimizers, self.lr_scheduler_configs.values() + ): + # Extract non-class arguments. + monitor = lr_scheduler_config.monitor + interval = lr_scheduler_config.interval + del lr_scheduler_config.monitor + del lr_scheduler_config.interval + + log.info( + f"Instantiating learning rate scheduler <{lr_scheduler_config._target_}>" + ) + scheduler = { + "monitor": monitor, + "interval": interval, + "scheduler": hydra.utils.instantiate( + lr_scheduler_config, optimizer=optimizer + ), + } + schedulers.append(scheduler) + + return schedulers + + def configure_optimizers( + self, + ) -> Tuple[List[Type[torch.optim.Optimizer]], List[Dict[str, Any]]]: """Configures optimizer and lr scheduler.""" - optimizer = self._configure_optimizer() - scheduler = self._configure_lr_scheduler(optimizer) - return [optimizer], [scheduler] + optimizers = self._configure_optimizer() + schedulers = self._configure_lr_scheduler(optimizers) + return optimizers, schedulers def forward(self, data: Tensor) -> Tensor: """Feedforward pass.""" |