From 38202e9c6c1155d96ee0f6e9f337022ee4eeb7e3 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Mon, 5 Apr 2021 23:12:20 +0200
Subject: Add OmegaConf for configs

---
 text_recognizer/models/base.py                | 55 ++++++++++++++++-----------
 text_recognizer/models/transformer.py         | 14 ++++---
 text_recognizer/networks/image_transformer.py | 42 ++++++++++++--------
 3 files changed, 68 insertions(+), 43 deletions(-)

(limited to 'text_recognizer')

diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 2d6e435..1004f48 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -1,33 +1,32 @@
 """Base PyTorch Lightning model."""
-from typing import Any, Dict, List, Tuple, Type
+from typing import Any, Dict, List, Union, Tuple, Type
 
 import madgrad
+from omegaconf import OmegaConf
 import pytorch_lightning as pl
 import torch
 from torch import nn
 from torch import Tensor
 import torchmetrics
 
-from text_recognizer import networks
-
 
 class LitBaseModel(pl.LightningModule):
     """Abstract PyTorch Lightning class."""
 
     def __init__(
         self,
-        network_args: Dict,
-        optimizer_args: Dict,
-        lr_scheduler_args: Dict,
-        criterion_args: Dict,
+        network: Type[nn,Module],
+        optimizer: Union[OmegaConf, Dict],
+        lr_scheduler: Union[OmegaConf, Dict],
+        criterion: Union[OmegaConf, Dict],
         monitor: str = "val_loss",
     ) -> None:
         super().__init__()
         self.monitor = monitor
-        self.network = getattr(networks, network_args["type"])(**network_args["args"])
-        self.optimizer_args = optimizer_args
-        self.lr_scheduler_args = lr_scheduler_args
-        self.loss_fn = self.configure_criterion(criterion_args)
+        self.network = network
+        self._optimizer = OmegaConf.create(optimizer)
+        self._lr_scheduler = OmegaConf.create(lr_scheduler)
+        self.loss_fn = self.configure_criterion(criterion)
 
         # Accuracy metric
         self.train_acc = torchmetrics.Accuracy()
@@ -35,27 +34,39 @@ class LitBaseModel(pl.LightningModule):
         self.test_acc = torchmetrics.Accuracy()
 
     @staticmethod
-    def configure_criterion(criterion_args: Dict) -> Type[nn.Module]:
+    def configure_criterion(criterion: Union[OmegaConf, Dict]) -> Type[nn.Module]:
         """Returns a loss functions."""
-        args = {} or criterion_args["args"]
-        return getattr(nn, criterion_args["type"])(**args)
+        criterion = OmegaConf.create(criterion)
+        args = {} or criterion.args
+        return getattr(nn, criterion.type)(**args)
 
-    def configure_optimizer(self) -> Tuple[List[type], List[Dict[str, Any]]]:
-        """Configures optimizer and lr scheduler."""
-        args = {} or self.optimizer_args["args"]
-        if self.optimizer_args["type"] == "MADGRAD":
-            optimizer = getattr(madgrad, self.optimizer_args["type"])(**args)
+    def _configure_optimizer(self) -> type:
+        """Configures the optimizer."""
+        args = {} or self._optimizer.args
+        if self._optimizer.type == "MADGRAD":
+            optimizer_class = madgrad.MADGRAD
         else:
-            optimizer = getattr(torch.optim, self.optimizer_args["type"])(**args)
+            optimizer_class = getattr(torch.optim, self._optimizer.type)
+        return optimizer_class(parameters=self.parameters(), **args)
 
+    def _configure_lr_scheduler(self) -> Dict[str, Any]:
+        """Configures the lr scheduler."""
         scheduler = {"monitor": self.monitor}
-        args = {} or self.lr_scheduler_args["args"]
+        args = {} or self._lr_scheduler.args
+
         if "interval" in args:
             scheduler["interval"] = args.pop("interval")
 
         scheduler["scheduler"] = getattr(
-            torch.optim.lr_scheduler, self.lr_scheduler_args["type"]
+            torch.optim.lr_scheduler, self._lr_scheduler.type
         )(**args)
+        return scheduler
+
+    def configure_optimizers(self) -> Tuple[List[type], List[Dict[str, Any]]]:
+        """Configures optimizer and lr scheduler."""
+        optimizer = self._configure_optimizer()
+        scheduler = self._configure_lr_scheduler()
+
         return [optimizer], [scheduler]
 
     def forward(self, data: Tensor) -> Tensor:
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 285b715..3625ab2 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,6 +1,7 @@
 """PyTorch Lightning model for base Transformers."""
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Union, Tuple
 
+from omegaconf import OmegaConf
 import pytorch_lightning as pl
 import torch
 from torch import nn
@@ -18,15 +19,15 @@ class LitTransformerModel(LitBaseModel):
 
     def __init__(
         self,
-        network_args: Dict,
-        optimizer_args: Dict,
-        lr_scheduler_args: Dict,
-        criterion_args: Dict,
+        network: Type[nn,Module],
+        optimizer: Union[OmegaConf, Dict],
+        lr_scheduler: Union[OmegaConf, Dict],
+        criterion: Union[OmegaConf, Dict],
         monitor: str = "val_loss",
         mapping: Optional[List[str]] = None,
     ) -> None:
         super().__init__(
-            network_args, optimizer_args, lr_scheduler_args, criterion_args, monitor
+            network, optimizer, lr_scheduler, criterion, monitor
         )
 
         self.mapping, ignore_tokens = self.configure_mapping(mapping)
@@ -40,6 +41,7 @@ class LitTransformerModel(LitBaseModel):
     @staticmethod
     def configure_mapping(mapping: Optional[List[str]]) -> Tuple[List[str], List[int]]:
         """Configure mapping."""
+        # TODO: Fix me!!!
         mapping, inverse_mapping, _ = emnist_mapping()
         start_index = inverse_mapping["<s>"]
         end_index = inverse_mapping["<e>"]
diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py
index b9254c9..aa024e0 100644
--- a/text_recognizer/networks/image_transformer.py
+++ b/text_recognizer/networks/image_transformer.py
@@ -8,10 +8,12 @@ together with the target tokens.
 TODO: Local attention for transformer.j
 
 """
+import importlib
 import math
-from typing import Any, Dict, List, Optional, Sequence, Type
+from typing import Dict, List, Union, Sequence, Tuple, Type
 
 from einops import rearrange
+from omegaconf import OmegaConf
 import torch
 from torch import nn
 from torch import Tensor
@@ -32,8 +34,8 @@ class ImageTransformer(nn.Module):
         self,
         input_shape: Sequence[int],
         output_shape: Sequence[int],
-        backbone: Type[nn.Module],
-        mapping: Optional[List[str]] = None,
+        encoder: Union[OmegaConf, Dict],
+        mapping: str,
         num_decoder_layers: int = 4,
         hidden_dim: int = 256,
         num_heads: int = 4,
@@ -51,8 +53,8 @@ class ImageTransformer(nn.Module):
         self.pad_index = inverse_mapping["<p>"]
 
         # Image backbone
-        self.backbone = backbone
-        self.latent_encoding = PositionalEncoding2D(
+        self.encoder = self._configure_encoder(encoder)
+        self.feature_map_encoding = PositionalEncoding2D(
             hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2]
         )
 
@@ -86,20 +88,30 @@ class ImageTransformer(nn.Module):
         self.head.weight.data.uniform_(-0.1, 0.1)
 
         nn.init.kaiming_normal_(
-            self.latent_encoding.weight.data, a=0, mode="fan_out", nonlinearity="relu"
+            self.feature_map_encoding.weight.data,
+            a=0,
+            mode="fan_out",
+            nonlinearity="relu",
         )
-        if self.latent_encoding.bias is not None:
+        if self.feature_map_encoding.bias is not None:
             _, fan_out = nn.init._calculate_fan_in_and_fan_out(
-                self.latent_encoding.weight.data
+                self.feature_map_encoding.weight.data
             )
             bound = 1 / math.sqrt(fan_out)
-            nn.init.normal_(self.latent_encoding.bias, -bound, bound)
+            nn.init.normal_(self.feature_map_encoding.bias, -bound, bound)
+
+    @staticmethod
+    def _configure_encoder(encoder: Union[OmegaConf, NamedTuple]) -> Type[nn.Module]:
+        encoder = OmegaConf.create(encoder)
+        network_module = importlib.import_module("text_recognizer.networks")
+        encoder_class = getattr(network_module, encoder.type)
+        return encoder_class(**encoder.args)
 
     def _configure_mapping(
-        self, mapping: Optional[List[str]]
+        self, mapping: str
     ) -> Tuple[List[str], Dict[str, int]]:
         """Configures mapping."""
-        if mapping is None:
+        if mapping == "emnist":
             mapping, inverse_mapping, _ = emnist_mapping()
         return mapping, inverse_mapping
 
@@ -118,14 +130,14 @@ class ImageTransformer(nn.Module):
 
         """
         # Extract image features.
-        latent = self.backbone(image)
+        image_features = self.encoder(image)
 
         # Add 2d encoding to the feature maps.
-        latent = self.latent_encoding(latent)
+        image_features = self.feature_map_encoding(image_features)
 
         # Collapse features maps height and width.
-        latent = rearrange(latent, "b c h w -> b (h w) c")
-        return latent
+        image_features = rearrange(image_features, "b c h w -> b (h w) c")
+        return image_features
 
     def decode(self, memory: Tensor, trg: Tensor) -> Tensor:
         """Decodes image features with transformer decoder."""
-- 
cgit v1.2.3-70-g09d2