diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-05 23:12:20 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-05 23:12:20 +0200 |
commit | 38202e9c6c1155d96ee0f6e9f337022ee4eeb7e3 (patch) | |
tree | aaa3f56495cdfbcc5f1434485fb237dfd6cf34a2 /text_recognizer/models | |
parent | bef106191e20b42741984c407dc4884ab1ee49eb (diff) |
Add OmegaConf for configs
Diffstat (limited to 'text_recognizer/models')
-rw-r--r-- | text_recognizer/models/base.py | 55 | ||||
-rw-r--r-- | text_recognizer/models/transformer.py | 14 |
2 files changed, 41 insertions, 28 deletions
```diff
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 2d6e435..1004f48 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -1,33 +1,32 @@
 """Base PyTorch Lightning model."""
-from typing import Any, Dict, List, Tuple, Type
+from typing import Any, Dict, List, Union, Tuple, Type
 
 import madgrad
+from omegaconf import OmegaConf
 import pytorch_lightning as pl
 import torch
 from torch import nn
 from torch import Tensor
 import torchmetrics
 
-from text_recognizer import networks
-
 
 class LitBaseModel(pl.LightningModule):
     """Abstract PyTorch Lightning class."""
 
     def __init__(
         self,
-        network_args: Dict,
-        optimizer_args: Dict,
-        lr_scheduler_args: Dict,
-        criterion_args: Dict,
+        network: Type[nn,Module],
+        optimizer: Union[OmegaConf, Dict],
+        lr_scheduler: Union[OmegaConf, Dict],
+        criterion: Union[OmegaConf, Dict],
         monitor: str = "val_loss",
     ) -> None:
         super().__init__()
         self.monitor = monitor
-        self.network = getattr(networks, network_args["type"])(**network_args["args"])
-        self.optimizer_args = optimizer_args
-        self.lr_scheduler_args = lr_scheduler_args
-        self.loss_fn = self.configure_criterion(criterion_args)
+        self.network = network
+        self._optimizer = OmegaConf.create(optimizer)
+        self._lr_scheduler = OmegaConf.create(lr_scheduler)
+        self.loss_fn = self.configure_criterion(criterion)
 
         # Accuracy metric
         self.train_acc = torchmetrics.Accuracy()
@@ -35,27 +34,39 @@ class LitBaseModel(pl.LightningModule):
         self.test_acc = torchmetrics.Accuracy()
 
     @staticmethod
-    def configure_criterion(criterion_args: Dict) -> Type[nn.Module]:
+    def configure_criterion(criterion: Union[OmegaConf, Dict]) -> Type[nn.Module]:
         """Returns a loss functions."""
-        args = {} or criterion_args["args"]
-        return getattr(nn, criterion_args["type"])(**args)
+        criterion = OmegaConf.create(criterion)
+        args = {} or criterion.args
+        return getattr(nn, criterion.type)(**args)
 
-    def configure_optimizer(self) -> Tuple[List[type], List[Dict[str, Any]]]:
-        """Configures optimizer and lr scheduler."""
-        args = {} or self.optimizer_args["args"]
-        if self.optimizer_args["type"] == "MADGRAD":
-            optimizer = getattr(madgrad, self.optimizer_args["type"])(**args)
+    def _configure_optimizer(self) -> type:
+        """Configures the optimizer."""
+        args = {} or self._optimizer.args
+        if self._optimizer.type == "MADGRAD":
+            optimizer_class = madgrad.MADGRAD
         else:
-            optimizer = getattr(torch.optim, self.optimizer_args["type"])(**args)
+            optimizer_class = getattr(torch.optim, self._optimizer.type)
+        return optimizer_class(parameters=self.parameters(), **args)
 
+    def _configure_lr_scheduler(self) -> Dict[str, Any]:
+        """Configures the lr scheduler."""
         scheduler = {"monitor": self.monitor}
-        args = {} or self.lr_scheduler_args["args"]
+        args = {} or self._lr_scheduler.args
+
         if "interval" in args:
             scheduler["interval"] = args.pop("interval")
 
         scheduler["scheduler"] = getattr(
-            torch.optim.lr_scheduler, self.lr_scheduler_args["type"]
+            torch.optim.lr_scheduler, self._lr_scheduler.type
         )(**args)
+        return scheduler
+
+    def configure_optimizers(self) -> Tuple[List[type], List[Dict[str, Any]]]:
+        """Configures optimizer and lr scheduler."""
+        optimizer = self._configure_optimizer()
+        scheduler = self._configure_lr_scheduler()
+
         return [optimizer], [scheduler]
 
     def forward(self, data: Tensor) -> Tensor:
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 285b715..3625ab2 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,6 +1,7 @@
 """PyTorch Lightning model for base Transformers."""
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Union, Tuple
 
+from omegaconf import OmegaConf
 import pytorch_lightning as pl
 import torch
 from torch import nn
@@ -18,15 +19,15 @@ class LitTransformerModel(LitBaseModel):
 
     def __init__(
         self,
-        network_args: Dict,
-        optimizer_args: Dict,
-        lr_scheduler_args: Dict,
-        criterion_args: Dict,
+        network: Type[nn,Module],
+        optimizer: Union[OmegaConf, Dict],
+        lr_scheduler: Union[OmegaConf, Dict],
+        criterion: Union[OmegaConf, Dict],
         monitor: str = "val_loss",
         mapping: Optional[List[str]] = None,
     ) -> None:
         super().__init__(
-            network_args, optimizer_args, lr_scheduler_args, criterion_args, monitor
+            network, optimizer, lr_scheduler, criterion, monitor
         )
 
         self.mapping, ignore_tokens = self.configure_mapping(mapping)
@@ -40,6 +41,7 @@ class LitTransformerModel(LitBaseModel):
     @staticmethod
     def configure_mapping(mapping: Optional[List[str]]) -> Tuple[List[str], List[int]]:
         """Configure mapping."""
+        # TODO: Fix me!!!
         mapping, inverse_mapping, _ = emnist_mapping()
         start_index = inverse_mapping["<s>"]
         end_index = inverse_mapping["<e>"]
```