summaryrefslogtreecommitdiff
path: root/text_recognizer/models/base.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-05 23:12:20 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-05 23:12:20 +0200
commit38202e9c6c1155d96ee0f6e9f337022ee4eeb7e3 (patch)
treeaaa3f56495cdfbcc5f1434485fb237dfd6cf34a2 /text_recognizer/models/base.py
parentbef106191e20b42741984c407dc4884ab1ee49eb (diff)
Add OmegaConf for configs
Diffstat (limited to 'text_recognizer/models/base.py')
-rw-r--r--text_recognizer/models/base.py55
1 files changed, 33 insertions, 22 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 2d6e435..1004f48 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -1,33 +1,32 @@
"""Base PyTorch Lightning model."""
-from typing import Any, Dict, List, Tuple, Type
+from typing import Any, Dict, List, Union, Tuple, Type
import madgrad
+from omegaconf import OmegaConf
import pytorch_lightning as pl
import torch
from torch import nn
from torch import Tensor
import torchmetrics
-from text_recognizer import networks
-
class LitBaseModel(pl.LightningModule):
"""Abstract PyTorch Lightning class."""
def __init__(
self,
- network_args: Dict,
- optimizer_args: Dict,
- lr_scheduler_args: Dict,
- criterion_args: Dict,
+ network: Type[nn,Module],
+ optimizer: Union[OmegaConf, Dict],
+ lr_scheduler: Union[OmegaConf, Dict],
+ criterion: Union[OmegaConf, Dict],
monitor: str = "val_loss",
) -> None:
super().__init__()
self.monitor = monitor
- self.network = getattr(networks, network_args["type"])(**network_args["args"])
- self.optimizer_args = optimizer_args
- self.lr_scheduler_args = lr_scheduler_args
- self.loss_fn = self.configure_criterion(criterion_args)
+ self.network = network
+ self._optimizer = OmegaConf.create(optimizer)
+ self._lr_scheduler = OmegaConf.create(lr_scheduler)
+ self.loss_fn = self.configure_criterion(criterion)
# Accuracy metric
self.train_acc = torchmetrics.Accuracy()
@@ -35,27 +34,39 @@ class LitBaseModel(pl.LightningModule):
self.test_acc = torchmetrics.Accuracy()
@staticmethod
- def configure_criterion(criterion_args: Dict) -> Type[nn.Module]:
+ def configure_criterion(criterion: Union[OmegaConf, Dict]) -> Type[nn.Module]:
"""Returns a loss functions."""
- args = {} or criterion_args["args"]
- return getattr(nn, criterion_args["type"])(**args)
+ criterion = OmegaConf.create(criterion)
+ args = {} or criterion.args
+ return getattr(nn, criterion.type)(**args)
- def configure_optimizer(self) -> Tuple[List[type], List[Dict[str, Any]]]:
- """Configures optimizer and lr scheduler."""
- args = {} or self.optimizer_args["args"]
- if self.optimizer_args["type"] == "MADGRAD":
- optimizer = getattr(madgrad, self.optimizer_args["type"])(**args)
+ def _configure_optimizer(self) -> type:
+ """Configures the optimizer."""
+ args = {} or self._optimizer.args
+ if self._optimizer.type == "MADGRAD":
+ optimizer_class = madgrad.MADGRAD
else:
- optimizer = getattr(torch.optim, self.optimizer_args["type"])(**args)
+ optimizer_class = getattr(torch.optim, self._optimizer.type)
+ return optimizer_class(parameters=self.parameters(), **args)
+ def _configure_lr_scheduler(self) -> Dict[str, Any]:
+ """Configures the lr scheduler."""
scheduler = {"monitor": self.monitor}
- args = {} or self.lr_scheduler_args["args"]
+ args = {} or self._lr_scheduler.args
+
if "interval" in args:
scheduler["interval"] = args.pop("interval")
scheduler["scheduler"] = getattr(
- torch.optim.lr_scheduler, self.lr_scheduler_args["type"]
+ torch.optim.lr_scheduler, self._lr_scheduler.type
)(**args)
+ return scheduler
+
+ def configure_optimizers(self) -> Tuple[List[type], List[Dict[str, Any]]]:
+ """Configures optimizer and lr scheduler."""
+ optimizer = self._configure_optimizer()
+ scheduler = self._configure_lr_scheduler()
+
return [optimizer], [scheduler]
def forward(self, data: Tensor) -> Tensor: