diff options
Diffstat (limited to 'text_recognizer/models')
-rw-r--r-- | text_recognizer/models/base.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 94dbde5..56d4ca5 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -1,5 +1,5 @@ """Base PyTorch Lightning model.""" -from typing import Any, Dict, List, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type import attr import hydra @@ -25,7 +25,7 @@ class BaseLitModel(LightningModule): mapping: Type[AbstractMapping] = attr.ib() loss_fn: Type[nn.Module] = attr.ib() optimizer_configs: DictConfig = attr.ib() - lr_scheduler_configs: DictConfig = attr.ib() + lr_scheduler_configs: Optional[DictConfig] = attr.ib() train_acc: torchmetrics.Accuracy = attr.ib( init=False, default=torchmetrics.Accuracy() ) @@ -55,9 +55,7 @@ class BaseLitModel(LightningModule): del optimizer_config.parameters log.info(f"Instantiating optimizer <{optimizer_config._target_}>") optimizers.append( - hydra.utils.instantiate( - optimizer_config, params=module.parameters() - ) + hydra.utils.instantiate(optimizer_config, params=module.parameters()) ) return optimizers @@ -65,6 +63,8 @@ class BaseLitModel(LightningModule): self, optimizers: List[Type[torch.optim.Optimizer]] ) -> List[Dict[str, Any]]: """Configures the lr scheduler.""" + if None in self.lr_scheduler_configs: + return [] schedulers = [] for optimizer, lr_scheduler_config in zip( optimizers, self.lr_scheduler_configs.values() |