summaryrefslogtreecommitdiff
path: root/text_recognizer/models
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-06-18 01:00:19 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-06-18 01:00:19 +0200
commitb22bd31b5df62b2d17bc060d35d73cfae95851af (patch)
treefa7f53e45b727a3e829f9bd813b2260835113f30 /text_recognizer/models
parent062eebf0d690365cf7d9f6019d147ea195cc3a63 (diff)
Refator loading of scheduler and optim
Diffstat (limited to 'text_recognizer/models')
-rw-r--r--text_recognizer/models/base.py70
1 files changed, 26 insertions, 44 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 26cc18c..1ebb256 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -1,5 +1,5 @@
"""Base PyTorch Lightning model."""
-from typing import Any, Dict, List, Optional, Tuple, Type
+from typing import Any, Dict, Optional, Tuple, Type
import hydra
from loguru import logger as log
@@ -47,57 +47,39 @@ class LitBase(LightningModule):
"""Optimal way to set grads to zero."""
optimizer.zero_grad(set_to_none=True)
- def _configure_optimizer(self) -> List[Type[torch.optim.Optimizer]]:
+ def _configure_optimizer(self) -> Type[torch.optim.Optimizer]:
"""Configures the optimizer."""
- optimizers = []
- for optimizer_config in self.optimizer_configs.values():
- module = self
- for m in str(optimizer_config.parameters).split("."):
- module = getattr(module, m)
- del optimizer_config.parameters
- log.info(f"Instantiating optimizer <{optimizer_config._target_}>")
- optimizers.append(
- hydra.utils.instantiate(optimizer_config, params=module.parameters())
- )
- return optimizers
+ return hydra.utils.instantiate(
+ self.optimizer_config, params=self.network.parameters()
+ )
def _configure_lr_schedulers(
- self, optimizers: List[Type[torch.optim.Optimizer]]
- ) -> List[Dict[str, Any]]:
+ self, optimizer: Type[torch.optim.Optimizer]
+ ) -> Dict[str, Any]:
"""Configures the lr scheduler."""
- if self.lr_scheduler_configs is None:
- return []
- schedulers = []
- for optimizer, lr_scheduler_config in zip(
- optimizers, self.lr_scheduler_configs.values()
- ):
- # Extract non-class arguments.
- monitor = lr_scheduler_config.monitor
- interval = lr_scheduler_config.interval
- del lr_scheduler_config.monitor
- del lr_scheduler_config.interval
-
- log.info(
- f"Instantiating learning rate scheduler <{lr_scheduler_config._target_}>"
- )
- scheduler = {
- "monitor": monitor,
- "interval": interval,
- "scheduler": hydra.utils.instantiate(
- lr_scheduler_config, optimizer=optimizer
- ),
- }
- schedulers.append(scheduler)
-
- return schedulers
+ log.info(
+ f"Instantiating learning rate scheduler <{self.lr_scheduler_config._target_}>"
+ )
+ monitor = self.lr_scheduler_config.monitor
+ interval = self.lr_scheduler_config.interval
+ del self.lr_scheduler_config.monitor
+ del self.lr_scheduler_config.interval
+
+ return {
+ "monitor": monitor,
+ "interval": interval,
+ "scheduler": hydra.utils.instantiate(
+ self.lr_scheduler_config, optimizer=optimizer
+ ),
+ }
def configure_optimizers(
self,
- ) -> Tuple[List[Type[torch.optim.Optimizer]], List[Dict[str, Any]]]:
+ ) -> Dict[str, Any]:
"""Configures optimizer and lr scheduler."""
- optimizers = self._configure_optimizer()
- schedulers = self._configure_lr_schedulers(optimizers)
- return optimizers, schedulers
+ optimizer = self._configure_optimizer()
+ scheduler = self._configure_lr_schedulers(optimizer)
+ return {"optimizer": optimizer, "scheduler": scheduler}
def forward(self, data: Tensor) -> Tensor:
"""Feedforward pass."""