Diffstat (limited to 'text_recognizer/models/base.py')
-rw-r--r--  text_recognizer/models/base.py  81
1 file changed, 48 insertions(+), 33 deletions(-)
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index ab3fa35..8b68ed9 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -24,8 +24,8 @@ class BaseLitModel(LightningModule):
     network: Type[nn.Module] = attr.ib()
     mapping: Type[AbstractMapping] = attr.ib()
     loss_fn: Type[nn.Module] = attr.ib()
-    optimizer_config: DictConfig = attr.ib()
-    lr_scheduler_config: DictConfig = attr.ib()
+    optimizer_configs: DictConfig = attr.ib()
+    lr_scheduler_configs: DictConfig = attr.ib()
     train_acc: torchmetrics.Accuracy = attr.ib(
         init=False, default=torchmetrics.Accuracy()
     )
@@ -45,40 +45,55 @@ class BaseLitModel(LightningModule):
     ) -> None:
         optimizer.zero_grad(set_to_none=True)

-    def _configure_optimizer(self) -> Type[torch.optim.Optimizer]:
+    def _configure_optimizers(self) -> List[Type[torch.optim.Optimizer]]:
         """Configures the optimizer."""
-        log.info(f"Instantiating optimizer <{self.optimizer_config._target_}>")
-        return hydra.utils.instantiate(
-            self.optimizer_config, params=self.network.parameters()
-        )
-
-    def _configure_lr_scheduler(
-        self, optimizer: Type[torch.optim.Optimizer]
-    ) -> Dict[str, Any]:
+        optimizers = []
+        for optimizer_config in self.optimizer_configs.values():
+            network = getattr(self, optimizer_config.parameters)
+            del optimizer_config.parameters
+            log.info(f"Instantiating optimizer <{optimizer_config._target_}>")
+            optimizers.append(
+                hydra.utils.instantiate(
+                    optimizer_config, params=network.parameters()
+                )
+            )
+        return optimizers
+
+    def _configure_lr_schedulers(
+        self, optimizers: List[Type[torch.optim.Optimizer]]
+    ) -> List[Dict[str, Any]]:
         """Configures the lr scheduler."""
-        # Extract non-class arguments.
-        monitor = self.lr_scheduler_config.monitor
-        interval = self.lr_scheduler_config.interval
-        del self.lr_scheduler_config.monitor
-        del self.lr_scheduler_config.interval
-
-        log.info(
-            f"Instantiating learning rate scheduler <{self.lr_scheduler_config._target_}>"
-        )
-        scheduler = {
-            "monitor": monitor,
-            "interval": interval,
-            "scheduler": hydra.utils.instantiate(
-                self.lr_scheduler_config, optimizer=optimizer
-            ),
-        }
-        return scheduler
-
-    def configure_optimizers(self) -> Tuple[List[type], List[Dict[str, Any]]]:
+        schedulers = []
+        for optimizer, lr_scheduler_config in zip(
+            optimizers, self.lr_scheduler_configs.values()
+        ):
+            # Extract non-class arguments.
+            monitor = lr_scheduler_config.monitor
+            interval = lr_scheduler_config.interval
+            del lr_scheduler_config.monitor
+            del lr_scheduler_config.interval
+
+            log.info(
+                f"Instantiating learning rate scheduler <{lr_scheduler_config._target_}>"
+            )
+            scheduler = {
+                "monitor": monitor,
+                "interval": interval,
+                "scheduler": hydra.utils.instantiate(
+                    lr_scheduler_config, optimizer=optimizer
+                ),
+            }
+            schedulers.append(scheduler)
+
+        return schedulers
+
+    def configure_optimizers(
+        self,
+    ) -> Tuple[List[Type[torch.optim.Optimizer]], List[Dict[str, Any]]]:
         """Configures optimizer and lr scheduler."""
-        optimizer = self._configure_optimizer()
-        scheduler = self._configure_lr_scheduler(optimizer)
-        return [optimizer], [scheduler]
+        optimizers = self._configure_optimizers()
+        schedulers = self._configure_lr_schedulers(optimizers)
+        return optimizers, schedulers

     def forward(self, data: Tensor) -> Tensor:
         """Feedforward pass."""