summaryrefslogtreecommitdiff
path: root/text_recognizer/models/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/models/base.py')
-rw-r--r--text_recognizer/models/base.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 94dbde5..56d4ca5 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -1,5 +1,5 @@
"""Base PyTorch Lightning model."""
-from typing import Any, Dict, List, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type
import attr
import hydra
@@ -25,7 +25,7 @@ class BaseLitModel(LightningModule):
mapping: Type[AbstractMapping] = attr.ib()
loss_fn: Type[nn.Module] = attr.ib()
optimizer_configs: DictConfig = attr.ib()
- lr_scheduler_configs: DictConfig = attr.ib()
+ lr_scheduler_configs: Optional[DictConfig] = attr.ib()
train_acc: torchmetrics.Accuracy = attr.ib(
init=False, default=torchmetrics.Accuracy()
)
@@ -55,9 +55,7 @@ class BaseLitModel(LightningModule):
del optimizer_config.parameters
log.info(f"Instantiating optimizer <{optimizer_config._target_}>")
optimizers.append(
- hydra.utils.instantiate(
- optimizer_config, params=module.parameters()
- )
+ hydra.utils.instantiate(optimizer_config, params=module.parameters())
)
return optimizers
@@ -65,6 +63,8 @@ class BaseLitModel(LightningModule):
self, optimizers: List[Type[torch.optim.Optimizer]]
) -> List[Dict[str, Any]]:
"""Configures the lr scheduler."""
+ if None in self.lr_scheduler_configs:
+ return []
schedulers = []
for optimizer, lr_scheduler_config in zip(
optimizers, self.lr_scheduler_configs.values()