diff options
Diffstat (limited to 'text_recognizer')
| -rw-r--r-- | text_recognizer/models/base.py | 10 | 
1 files changed, 5 insertions, 5 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 94dbde5..56d4ca5 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -1,5 +1,5 @@  """Base PyTorch Lightning model.""" -from typing import Any, Dict, List, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type  import attr  import hydra @@ -25,7 +25,7 @@ class BaseLitModel(LightningModule):      mapping: Type[AbstractMapping] = attr.ib()      loss_fn: Type[nn.Module] = attr.ib()      optimizer_configs: DictConfig = attr.ib() -    lr_scheduler_configs: DictConfig = attr.ib() +    lr_scheduler_configs: Optional[DictConfig] = attr.ib()      train_acc: torchmetrics.Accuracy = attr.ib(          init=False, default=torchmetrics.Accuracy()      ) @@ -55,9 +55,7 @@ class BaseLitModel(LightningModule):              del optimizer_config.parameters              log.info(f"Instantiating optimizer <{optimizer_config._target_}>")              optimizers.append( -                hydra.utils.instantiate( -                    optimizer_config, params=module.parameters() -                ) +                hydra.utils.instantiate(optimizer_config, params=module.parameters())              )          return optimizers @@ -65,6 +63,8 @@ class BaseLitModel(LightningModule):          self, optimizers: List[Type[torch.optim.Optimizer]]      ) -> List[Dict[str, Any]]:          """Configures the lr scheduler.""" +        if None in self.lr_scheduler_configs: +            return []          schedulers = []          for optimizer, lr_scheduler_config in zip(              optimizers, self.lr_scheduler_configs.values()  |