summaryrefslogtreecommitdiff
path: root/text_recognizer/models/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/models/base.py')
-rw-r--r--text_recognizer/models/base.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 8b68ed9..94dbde5 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -49,12 +49,14 @@ class BaseLitModel(LightningModule):
"""Configures the optimizer."""
optimizers = []
for optimizer_config in self.optimizer_configs.values():
- network = getattr(self, optimizer_config.parameters)
+ module = self
+ for m in str(optimizer_config.parameters).split("."):
+ module = getattr(module, m)
del optimizer_config.parameters
log.info(f"Instantiating optimizer <{optimizer_config._target_}>")
optimizers.append(
hydra.utils.instantiate(
- self.optimizer_config, params=network.parameters()
+ optimizer_config, params=module.parameters()
)
)
return optimizers
@@ -92,7 +94,7 @@ class BaseLitModel(LightningModule):
) -> Tuple[List[Type[torch.optim.Optimizer]], List[Dict[str, Any]]]:
"""Configures optimizer and lr scheduler."""
optimizers = self._configure_optimizer()
- schedulers = self._configure_lr_scheduler(optimizers)
+ schedulers = self._configure_lr_schedulers(optimizers)
return optimizers, schedulers
def forward(self, data: Tensor) -> Tensor: