diff options
Diffstat (limited to 'text_recognizer/models')
-rw-r--r-- | text_recognizer/models/base.py | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 46e5136..2d6e435 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -1,5 +1,5 @@ """Base PyTorch Lightning model.""" -from typing import Any, Dict, Tuple, Type +from typing import Any, Dict, List, Tuple, Type import madgrad import pytorch_lightning as pl @@ -40,7 +40,7 @@ class LitBaseModel(pl.LightningModule): args = {} or criterion_args["args"] return getattr(nn, criterion_args["type"])(**args) - def configure_optimizer(self) -> Dict[str, Any]: + def configure_optimizer(self) -> Tuple[List[type], List[Dict[str, Any]]]: """Configures optimizer and lr scheduler.""" args = {} or self.optimizer_args["args"] if self.optimizer_args["type"] == "MADGRAD": @@ -48,15 +48,15 @@ class LitBaseModel(pl.LightningModule): else: optimizer = getattr(torch.optim, self.optimizer_args["type"])(**args) + scheduler = {"monitor": self.monitor} args = {} or self.lr_scheduler_args["args"] - scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_args["type"])( - **args - ) - return { - "optimizer": optimizer, - "lr_scheduler": scheduler, - "monitor": self.monitor, - } + if "interval" in args: + scheduler["interval"] = args.pop("interval") + + scheduler["scheduler"] = getattr( + torch.optim.lr_scheduler, self.lr_scheduler_args["type"] + )(**args) + return [optimizer], [scheduler] def forward(self, data: Tensor) -> Tensor: """Feedforward pass.""" |