From 38202e9c6c1155d96ee0f6e9f337022ee4eeb7e3 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 5 Apr 2021 23:12:20 +0200 Subject: Add OmegaConf for configs --- text_recognizer/models/base.py | 55 ++++++++++++++++----------- text_recognizer/models/transformer.py | 14 ++++--- text_recognizer/networks/image_transformer.py | 42 ++++++++++++-------- 3 files changed, 68 insertions(+), 43 deletions(-) diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 2d6e435..1004f48 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -1,33 +1,32 @@ """Base PyTorch Lightning model.""" -from typing import Any, Dict, List, Tuple, Type +from typing import Any, Dict, List, Union, Tuple, Type import madgrad +from omegaconf import OmegaConf import pytorch_lightning as pl import torch from torch import nn from torch import Tensor import torchmetrics -from text_recognizer import networks - class LitBaseModel(pl.LightningModule): """Abstract PyTorch Lightning class.""" def __init__( self, - network_args: Dict, - optimizer_args: Dict, - lr_scheduler_args: Dict, - criterion_args: Dict, + network: Type[nn,Module], + optimizer: Union[OmegaConf, Dict], + lr_scheduler: Union[OmegaConf, Dict], + criterion: Union[OmegaConf, Dict], monitor: str = "val_loss", ) -> None: super().__init__() self.monitor = monitor - self.network = getattr(networks, network_args["type"])(**network_args["args"]) - self.optimizer_args = optimizer_args - self.lr_scheduler_args = lr_scheduler_args - self.loss_fn = self.configure_criterion(criterion_args) + self.network = network + self._optimizer = OmegaConf.create(optimizer) + self._lr_scheduler = OmegaConf.create(lr_scheduler) + self.loss_fn = self.configure_criterion(criterion) # Accuracy metric self.train_acc = torchmetrics.Accuracy() @@ -35,27 +34,39 @@ class LitBaseModel(pl.LightningModule): self.test_acc = torchmetrics.Accuracy() @staticmethod - def configure_criterion(criterion_args: Dict) -> Type[nn.Module]: + def configure_criterion(criterion: Union[OmegaConf, Dict]) -> Type[nn.Module]: """Returns a loss functions.""" - args = {} or criterion_args["args"] - return getattr(nn, criterion_args["type"])(**args) + criterion = OmegaConf.create(criterion) + args = {} or criterion.args + return getattr(nn, criterion.type)(**args) - def configure_optimizer(self) -> Tuple[List[type], List[Dict[str, Any]]]: - """Configures optimizer and lr scheduler.""" - args = {} or self.optimizer_args["args"] - if self.optimizer_args["type"] == "MADGRAD": - optimizer = getattr(madgrad, self.optimizer_args["type"])(**args) + def _configure_optimizer(self) -> type: + """Configures the optimizer.""" + args = {} or self._optimizer.args + if self._optimizer.type == "MADGRAD": + optimizer_class = madgrad.MADGRAD else: - optimizer = getattr(torch.optim, self.optimizer_args["type"])(**args) + optimizer_class = getattr(torch.optim, self._optimizer.type) + return optimizer_class(parameters=self.parameters(), **args) + def _configure_lr_scheduler(self) -> Dict[str, Any]: + """Configures the lr scheduler.""" scheduler = {"monitor": self.monitor} - args = {} or self.lr_scheduler_args["args"] + args = {} or self._lr_scheduler.args + if "interval" in args: scheduler["interval"] = args.pop("interval") scheduler["scheduler"] = getattr( - torch.optim.lr_scheduler, self.lr_scheduler_args["type"] + torch.optim.lr_scheduler, self._lr_scheduler.type )(**args) + return scheduler + + def configure_optimizers(self) -> Tuple[List[type], List[Dict[str, Any]]]: + """Configures optimizer and lr scheduler.""" + optimizer = self._configure_optimizer() + scheduler = self._configure_lr_scheduler() + return [optimizer], [scheduler] def forward(self, data: Tensor) -> Tensor: diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 285b715..3625ab2 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,6 +1,7 @@ """PyTorch Lightning model for base Transformers.""" -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Union, Tuple +from omegaconf import OmegaConf import pytorch_lightning as pl import torch from torch import nn @@ -18,15 +19,15 @@ class LitTransformerModel(LitBaseModel): def __init__( self, - network_args: Dict, - optimizer_args: Dict, - lr_scheduler_args: Dict, - criterion_args: Dict, + network: Type[nn,Module], + optimizer: Union[OmegaConf, Dict], + lr_scheduler: Union[OmegaConf, Dict], + criterion: Union[OmegaConf, Dict], monitor: str = "val_loss", mapping: Optional[List[str]] = None, ) -> None: super().__init__( - network_args, optimizer_args, lr_scheduler_args, criterion_args, monitor + network, optimizer, lr_scheduler, criterion, monitor ) self.mapping, ignore_tokens = self.configure_mapping(mapping) @@ -40,6 +41,7 @@ class LitTransformerModel(LitBaseModel): @staticmethod def configure_mapping(mapping: Optional[List[str]]) -> Tuple[List[str], List[int]]: """Configure mapping.""" + # TODO: Fix me!!! mapping, inverse_mapping, _ = emnist_mapping() start_index = inverse_mapping[""] end_index = inverse_mapping[""] diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py index b9254c9..aa024e0 100644 --- a/text_recognizer/networks/image_transformer.py +++ b/text_recognizer/networks/image_transformer.py @@ -8,10 +8,12 @@ together with the target tokens. TODO: Local attention for transformer.j """ +import importlib import math -from typing import Any, Dict, List, Optional, Sequence, Type +from typing import Dict, List, Union, Sequence, Tuple, Type from einops import rearrange +from omegaconf import OmegaConf import torch from torch import nn from torch import Tensor @@ -32,8 +34,8 @@ class ImageTransformer(nn.Module): self, input_shape: Sequence[int], output_shape: Sequence[int], - backbone: Type[nn.Module], - mapping: Optional[List[str]] = None, + encoder: Union[OmegaConf, Dict], + mapping: str, num_decoder_layers: int = 4, hidden_dim: int = 256, num_heads: int = 4, @@ -51,8 +53,8 @@ class ImageTransformer(nn.Module): self.pad_index = inverse_mapping["

"] # Image backbone - self.backbone = backbone - self.latent_encoding = PositionalEncoding2D( + self.encoder = self._configure_encoder(encoder) + self.feature_map_encoding = PositionalEncoding2D( hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2] ) @@ -86,20 +88,30 @@ class ImageTransformer(nn.Module): self.head.weight.data.uniform_(-0.1, 0.1) nn.init.kaiming_normal_( - self.latent_encoding.weight.data, a=0, mode="fan_out", nonlinearity="relu" + self.feature_map_encoding.weight.data, + a=0, + mode="fan_out", + nonlinearity="relu", ) - if self.latent_encoding.bias is not None: + if self.feature_map_encoding.bias is not None: _, fan_out = nn.init._calculate_fan_in_and_fan_out( - self.latent_encoding.weight.data + self.feature_map_encoding.weight.data ) bound = 1 / math.sqrt(fan_out) - nn.init.normal_(self.latent_encoding.bias, -bound, bound) + nn.init.normal_(self.feature_map_encoding.bias, -bound, bound) + + @staticmethod + def _configure_encoder(encoder: Union[OmegaConf, NamedTuple]) -> Type[nn.Module]: + encoder = OmegaConf.create(encoder) + network_module = importlib.import_module("text_recognizer.networks") + encoder_class = getattr(network_module, encoder.type) + return encoder_class(**encoder.args) def _configure_mapping( - self, mapping: Optional[List[str]] + self, mapping: str ) -> Tuple[List[str], Dict[str, int]]: """Configures mapping.""" - if mapping is None: + if mapping == "emnist": mapping, inverse_mapping, _ = emnist_mapping() return mapping, inverse_mapping @@ -118,14 +130,14 @@ class ImageTransformer(nn.Module): """ # Extract image features. - latent = self.backbone(image) + image_features = self.encoder(image) # Add 2d encoding to the feature maps. - latent = self.latent_encoding(latent) + image_features = self.feature_map_encoding(image_features) # Collapse features maps height and width. - latent = rearrange(latent, "b c h w -> b (h w) c") - return latent + image_features = rearrange(image_features, "b c h w -> b (h w) c") + return image_features def decode(self, memory: Tensor, trg: Tensor) -> Tensor: """Decodes image features with transformer decoder.""" -- cgit v1.2.3-70-g09d2