From ff9a21d333f11a42e67c1963ed67de9c0fda87c9 Mon Sep 17 00:00:00 2001
From: aktersnurra <gustaf.rydholm@gmail.com>
Date: Thu, 7 Jan 2021 20:10:54 +0100
Subject: Minor updates.

---
 src/text_recognizer/datasets/dataset.py            |   4 +-
 .../datasets/emnist_lines_dataset.py               |   3 +
 src/text_recognizer/datasets/iam_lines_dataset.py  |   2 +
 src/text_recognizer/datasets/transforms.py         |   9 ++
 src/text_recognizer/datasets/util.py               |  15 ++-
 src/text_recognizer/models/__init__.py             |   2 +
 .../models/ctc_transformer_model.py                | 120 +++++++++++++++++
 src/text_recognizer/models/transformer_model.py    |   4 +-
 src/text_recognizer/networks/__init__.py           |   6 +-
 src/text_recognizer/networks/cnn_transformer.py    |  47 +++++--
 src/text_recognizer/networks/metrics.py            |  25 ++--
 src/text_recognizer/networks/residual_network.py   |   4 +-
 .../networks/transformer/__init__.py               |   2 +-
 .../networks/transformer/transformer.py            |  26 +++-
 src/text_recognizer/networks/util.py               |   1 +
 src/text_recognizer/networks/vit.py                | 150 +++++++++++++++++++++
 src/text_recognizer/networks/wide_resnet.py        |  13 +-
 17 files changed, 388 insertions(+), 45 deletions(-)
 create mode 100644 src/text_recognizer/models/ctc_transformer_model.py
 create mode 100644 src/text_recognizer/networks/vit.py

diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py
index 95063bc..e794605 100644
--- a/src/text_recognizer/datasets/dataset.py
+++ b/src/text_recognizer/datasets/dataset.py
@@ -22,6 +22,7 @@ class Dataset(data.Dataset):
         init_token: Optional[str] = None,
         pad_token: Optional[str] = None,
         eos_token: Optional[str] = None,
+        lower: bool = False,
     ) -> None:
         """Initialization of Dataset class.
 
@@ -33,6 +34,7 @@ class Dataset(data.Dataset):
             init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
             pad_token (Optional[str]): String representing the pad token. Defaults to None.
             eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
+            lower (bool): Only use lower case letters. Defaults to False.
 
         Raises:
             ValueError: If subsample_fraction is not None and outside the range (0, 1).
@@ -47,7 +49,7 @@ class Dataset(data.Dataset):
         self.subsample_fraction = subsample_fraction
 
         self._mapper = EmnistMapper(
-            init_token=init_token, eos_token=eos_token, pad_token=pad_token
+            init_token=init_token, eos_token=eos_token, pad_token=pad_token, lower=lower
         )
         self._input_shape = self._mapper.input_shape
         self._output_shape = self._mapper._num_classes
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index eddf341..1992446 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -44,6 +44,7 @@ class EmnistLinesDataset(Dataset):
         init_token: Optional[str] = None,
         pad_token: Optional[str] = None,
         eos_token: Optional[str] = None,
+        lower: bool = False,
     ) -> None:
         """Set attributes and loads the dataset.
 
@@ -60,6 +61,7 @@ class EmnistLinesDataset(Dataset):
             init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
             pad_token (Optional[str]): String representing the pad token. Defaults to None.
             eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
+            lower (bool): If True, convert uppercase letters to lowercase. Otherwise, use both upper and lowercase.
 
         """
         self.pad_token = "_" if pad_token is None else pad_token
@@ -72,6 +74,7 @@ class EmnistLinesDataset(Dataset):
             init_token=init_token,
             pad_token=self.pad_token,
             eos_token=eos_token,
+            lower=lower,
         )
 
         # Extract dataset information.
diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/src/text_recognizer/datasets/iam_lines_dataset.py
index 5ae142c..1cb84bd 100644
--- a/src/text_recognizer/datasets/iam_lines_dataset.py
+++ b/src/text_recognizer/datasets/iam_lines_dataset.py
@@ -35,6 +35,7 @@ class IamLinesDataset(Dataset):
         init_token: Optional[str] = None,
         pad_token: Optional[str] = None,
         eos_token: Optional[str] = None,
+        lower: bool = False,
     ) -> None:
         self.pad_token = "_" if pad_token is None else pad_token
 
@@ -46,6 +47,7 @@ class IamLinesDataset(Dataset):
             init_token=init_token,
             pad_token=pad_token,
             eos_token=eos_token,
+            lower=lower,
         )
 
     @property
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
index 016ec80..8956b01 100644
--- a/src/text_recognizer/datasets/transforms.py
+++ b/src/text_recognizer/datasets/transforms.py
@@ -93,3 +93,12 @@ class Squeeze:
     def __call__(self, x: Tensor) -> Tensor:
         """Removes first dim."""
         return x.squeeze(0)
+
+
+class ToLower:
+    """Converts target to lower case."""
+
+    def __call__(self, target: Tensor) -> Tensor:
+        """Corrects index value in target tensor."""
+        device = target.device
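+        # Indices 36-61 are assumed to hold the lower-case letters in the full EMNIST mapping;
+        # shifting them down by 26 makes every label fit the lower-case-only mapping.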
+        return torch.stack([x - 26 if x > 35 else x for x in target]).to(device)
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index bf5e772..da87756 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -1,17 +1,14 @@
 """Util functions for datasets."""
 import hashlib
-import importlib
 import json
 import os
 from pathlib import Path
 import string
-from typing import Callable, Dict, List, Optional, Type, Union
-from urllib.request import urlopen, urlretrieve
+from typing import Dict, List, Optional, Union
+from urllib.request import urlretrieve
 
-import cv2
 from loguru import logger
 import numpy as np
-from PIL import Image
 import torch
 from torch import Tensor
 from torchvision.datasets import EMNIST
@@ -50,11 +47,13 @@ class EmnistMapper:
         pad_token: str,
         init_token: Optional[str] = None,
         eos_token: Optional[str] = None,
+        lower: bool = False,
     ) -> None:
         """Loads the emnist essentials file with the mapping and input shape."""
         self.init_token = init_token
         self.pad_token = pad_token
         self.eos_token = eos_token
+        self.lower = lower
 
         self.essentials = self._load_emnist_essentials()
         # Load dataset information.
@@ -120,6 +119,12 @@ class EmnistMapper:
     def _augment_emnist_mapping(self) -> None:
         """Augment the mapping with extra symbols."""
         # Extra symbols in IAM dataset
+        if self.lower:
+            self._mapping = {
+                k: str(v)
+                for k, v in enumerate(list(range(10)) + list(string.ascii_lowercase))
+            }
+
         extra_symbols = [
             " ",
             "!",
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py
index a645cec..eb5dbce 100644
--- a/src/text_recognizer/models/__init__.py
+++ b/src/text_recognizer/models/__init__.py
@@ -2,12 +2,14 @@
 from .base import Model
 from .character_model import CharacterModel
 from .crnn_model import CRNNModel
+from .ctc_transformer_model import CTCTransformerModel
 from .segmentation_model import SegmentationModel
 from .transformer_model import TransformerModel
 
 __all__ = [
     "CharacterModel",
     "CRNNModel",
+    "CTCTransformerModel",
     "Model",
     "SegmentationModel",
     "TransformerModel",
diff --git a/src/text_recognizer/models/ctc_transformer_model.py b/src/text_recognizer/models/ctc_transformer_model.py
new file mode 100644
index 0000000..25925f2
--- /dev/null
+++ b/src/text_recognizer/models/ctc_transformer_model.py
@@ -0,0 +1,120 @@
+"""Defines the CTC Transformer Model class."""
+from typing import Callable, Dict, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets import EmnistMapper
+from text_recognizer.models.base import Model
+from text_recognizer.networks import greedy_decoder
+
+
+class CTCTransformerModel(Model):
+    """Model for predicting a sequence of characters from an image of a text line."""
+
+    def __init__(
+        self,
+        network_fn: Type[nn.Module],
+        dataset: Type[Dataset],
+        network_args: Optional[Dict] = None,
+        dataset_args: Optional[Dict] = None,
+        metrics: Optional[Dict] = None,
+        criterion: Optional[Callable] = None,
+        criterion_args: Optional[Dict] = None,
+        optimizer: Optional[Callable] = None,
+        optimizer_args: Optional[Dict] = None,
+        lr_scheduler: Optional[Callable] = None,
+        lr_scheduler_args: Optional[Dict] = None,
+        swa_args: Optional[Dict] = None,
+        device: Optional[str] = None,
+    ) -> None:
+        super().__init__(
+            network_fn,
+            dataset,
+            network_args,
+            dataset_args,
+            metrics,
+            criterion,
+            criterion_args,
+            optimizer,
+            optimizer_args,
+            lr_scheduler,
+            lr_scheduler_args,
+            swa_args,
+            device,
+        )
+        self.pad_token = dataset_args["args"]["pad_token"]
+        self.lower = dataset_args["args"]["lower"]
+
+        if self._mapper is None:
+            self._mapper = EmnistMapper(pad_token=self.pad_token, lower=self.lower,)
+
+        self.tensor_transform = ToTensor()
+
+    def criterion(self, output: Tensor, targets: Tensor) -> Tensor:
+        """Computes the CTC loss.
+
+        Args:
+            output (Tensor): Model predictions.
+            targets (Tensor): Correct output sequence.
+
+        Returns:
+            Tensor: The CTC loss.
+
+        """
+        # The network output has shape [T, B, C], so every batch element has input length T.
+        input_lengths = torch.full(
+            size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long,
+        )
+
+        # Configure target tensors for ctc loss.
+        targets_ = Tensor([]).to(self.device)
+        target_lengths = []
+        for t in targets:
+            # Remove padding symbol as it acts as the blank symbol.
+            t = t[t < 53]
+            targets_ = torch.cat([targets_, t])
+            target_lengths.append(len(t))
+
+        targets = targets_.type(dtype=torch.long)
+        target_lengths = (
+            torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device)
+        )
+
+        return self._criterion(output, targets, input_lengths, target_lengths)
+
+    @torch.no_grad()
+    def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
+        """Predict on a single input."""
+        self.eval()
+
+        if image.dtype == np.uint8:
+            # Converts an image with range [0, 255] to a PyTorch tensor with range [0, 1].
+            image = self.tensor_transform(image)
+
+        # Rescale image between 0 and 1.
+        if image.dtype == torch.uint8:
+            # If the image is an unscaled tensor.
+            image = image.type("torch.FloatTensor") / 255
+
+        # Put the image tensor on the device the model weights are on.
+        image = image.to(self.device)
+        log_probs = self.forward(image)
+
+        raw_pred, _ = greedy_decoder(
+            predictions=log_probs,
+            character_mapper=self.mapper,
+            blank_label=53,
+            collapse_repeated=True,
+        )
+
+        log_probs, _ = log_probs.max(dim=2)
+
+        predicted_characters = "".join(raw_pred[0])
+        confidence_of_prediction = log_probs.cumprod(dim=0)[-1].item()
+
+        return predicted_characters, confidence_of_prediction
diff --git a/src/text_recognizer/models/transformer_model.py b/src/text_recognizer/models/transformer_model.py
index a912122..12e497f 100644
--- a/src/text_recognizer/models/transformer_model.py
+++ b/src/text_recognizer/models/transformer_model.py
@@ -50,13 +50,15 @@ class TransformerModel(Model):
         self.init_token = dataset_args["args"]["init_token"]
         self.pad_token = dataset_args["args"]["pad_token"]
         self.eos_token = dataset_args["args"]["eos_token"]
-        self.max_len = 120
+        self.lower = dataset_args["args"]["lower"]
+        self.max_len = 100
 
         if self._mapper is None:
             self._mapper = EmnistMapper(
                 init_token=self.init_token,
                 pad_token=self.pad_token,
                 eos_token=self.eos_token,
+                lower=self.lower,
             )
         self.tensor_transform = ToTensor()
 
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index f958672..2b624bb 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -3,19 +3,18 @@ from .cnn_transformer import CNNTransformer
 from .crnn import ConvolutionalRecurrentNetwork
 from .ctc import greedy_decoder
 from .densenet import DenseNet
-from .fcn import FCN
 from .lenet import LeNet
-from .metrics import accuracy, accuracy_ignore_pad, cer, wer
+from .metrics import accuracy, cer, wer
 from .mlp import MLP
 from .residual_network import ResidualNetwork, ResidualNetworkEncoder
 from .transformer import Transformer
 from .unet import UNet
 from .util import sliding_window
+from .vit import ViT
 from .wide_resnet import WideResidualNetwork
 
 __all__ = [
     "accuracy",
-    "accuracy_ignore_pad",
     "cer",
     "CNNTransformer",
     "ConvolutionalRecurrentNetwork",
@@ -29,6 +28,7 @@ __all__ = [
     "sliding_window",
     "UNet",
     "Transformer",
+    "ViT",
     "wer",
     "WideResidualNetwork",
 ]
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
index b2b74b3..caa73e3 100644
--- a/src/text_recognizer/networks/cnn_transformer.py
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -1,12 +1,13 @@
 """A CNN-Transformer for image to text recognition."""
 from typing import Dict, Optional, Tuple
 
-from einops import rearrange
+from einops import rearrange, repeat
 import torch
 from torch import nn
 from torch import Tensor
 
 from text_recognizer.networks.transformer import PositionalEncoding, Transformer
+from text_recognizer.networks.util import activation_function
 from text_recognizer.networks.util import configure_backbone
 
 
@@ -24,15 +25,21 @@ class CNNTransformer(nn.Module):
         expansion_dim: int,
         dropout_rate: float,
         trg_pad_index: int,
+        max_len: int,
         backbone: str,
         backbone_args: Optional[Dict] = None,
         activation: str = "gelu",
     ) -> None:
         super().__init__()
         self.trg_pad_index = trg_pad_index
+        self.vocab_size = vocab_size
         self.backbone = configure_backbone(backbone, backbone_args)
-        self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
-        self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+        self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim)
+
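+        # Learned positional embedding for the source (image feature) sequence; the target keeps
+        # the existing PositionalEncoding module.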
+        self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
+        self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+
+        nn.init.normal_(self.character_embedding.weight, std=0.02)
 
         self.adaptive_pool = (
             nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None
@@ -48,7 +55,11 @@ class CNNTransformer(nn.Module):
             activation,
         )
 
-        self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
+        self.head = nn.Sequential(
+            # nn.Linear(hidden_dim, hidden_dim * 2),
+            # activation_function(activation),
+            nn.Linear(hidden_dim, vocab_size),
+        )
 
     def _create_trg_mask(self, trg: Tensor) -> Tensor:
         # Move this outside the transformer.
@@ -96,7 +107,21 @@ class CNNTransformer(nn.Module):
         else:
             src = rearrange(src, "b c h w -> b (w h) c")
 
-        src = self.position_encoding(src)
+        b, t, _ = src.shape
+
+        # Insert sos and eos token.
+        # sos_token = self.character_embedding(
+        #    torch.Tensor([self.vocab_size - 2]).long().to(src.device)
+        # )
+        # eos_token = self.character_embedding(
+        #    torch.Tensor([self.vocab_size - 1]).long().to(src.device)
+        # )
+
+        # sos_tokens = repeat(sos_token, "() h -> b h", b=b).unsqueeze(1)
+        # eos_tokens = repeat(eos_token, "() h -> b h", b=b).unsqueeze(1)
+        # src = torch.cat((sos_tokens, src, eos_tokens), dim=1)
+        # src = torch.cat((sos_tokens, src), dim=1)
+        src += self.src_position_embedding[:, :t]
 
         return src
 
@@ -111,20 +136,22 @@ class CNNTransformer(nn.Module):
 
         """
         trg = self.character_embedding(trg.long())
-        trg = self.position_encoding(trg)
+        trg = self.trg_position_encoding(trg)
         return trg
 
-    def decode_image_features(self, h: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+    def decode_image_features(
+        self, image_features: Tensor, trg: Optional[Tensor] = None
+    ) -> Tensor:
         """Takes images features from the backbone and decodes them with the transformer."""
         trg_mask = self._create_trg_mask(trg)
         trg = self.target_embedding(trg)
-        out = self.transformer(h, trg, trg_mask=trg_mask)
+        out = self.transformer(image_features, trg, trg_mask=trg_mask)
 
         logits = self.head(out)
         return logits
 
     def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
         """Forward pass with CNN transfomer."""
-        h = self.extract_image_features(x)
-        logits = self.decode_image_features(h, trg)
+        image_features = self.extract_image_features(x)
+        logits = self.decode_image_features(image_features, trg)
         return logits
diff --git a/src/text_recognizer/networks/metrics.py b/src/text_recognizer/networks/metrics.py
index af9adb5..ffad792 100644
--- a/src/text_recognizer/networks/metrics.py
+++ b/src/text_recognizer/networks/metrics.py
@@ -6,28 +6,13 @@ from torch import Tensor
 from text_recognizer.networks import greedy_decoder
 
 
-def accuracy_ignore_pad(
-    output: Tensor,
-    target: Tensor,
-    pad_index: int = 79,
-    eos_index: int = 81,
-    seq_len: int = 97,
-) -> float:
-    """Sets all predictions after eos to pad."""
-    start_indices = torch.nonzero(target == eos_index, as_tuple=False).squeeze(1)
-    end_indices = torch.arange(seq_len, target.shape[0] + 1, seq_len)
-    for start, stop in zip(start_indices, end_indices):
-        output[start + 1 : stop] = pad_index
-
-    return accuracy(output, target)
-
-
-def accuracy(outputs: Tensor, labels: Tensor,) -> float:
+def accuracy(outputs: Tensor, labels: Tensor, pad_index: int = 53) -> float:
     """Computes the accuracy.
 
     Args:
         outputs (Tensor): The output from the network.
         labels (Tensor): Ground truth labels.
+        pad_index (int): Padding index.
 
     Returns:
         float: The accuracy for the batch.
@@ -36,6 +21,12 @@ def accuracy(outputs: Tensor, labels: Tensor,) -> float:
 
     _, predicted = torch.max(outputs, dim=-1)
 
+    # Mask out the pad tokens
+    mask = labels != pad_index
+
+    predicted *= mask
+    labels *= mask
+
     acc = (predicted == labels).sum().float() / labels.shape[0]
     acc = acc.item()
     return acc
diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py
index e397224..c33f419 100644
--- a/src/text_recognizer/networks/residual_network.py
+++ b/src/text_recognizer/networks/residual_network.py
@@ -221,8 +221,8 @@ class ResidualNetworkEncoder(nn.Module):
             nn.Conv2d(
                 in_channels=in_channels,
                 out_channels=self.block_sizes[0],
-                kernel_size=3,
-                stride=1,
+                kernel_size=7,
+                stride=2,
                 padding=1,
                 bias=False,
             ),
diff --git a/src/text_recognizer/networks/transformer/__init__.py b/src/text_recognizer/networks/transformer/__init__.py
index 020a917..9febc88 100644
--- a/src/text_recognizer/networks/transformer/__init__.py
+++ b/src/text_recognizer/networks/transformer/__init__.py
@@ -1,3 +1,3 @@
 """Transformer modules."""
 from .positional_encoding import PositionalEncoding
-from .transformer import Decoder, Encoder, Transformer
+from .transformer import Decoder, Encoder, EncoderLayer, Transformer
diff --git a/src/text_recognizer/networks/transformer/transformer.py b/src/text_recognizer/networks/transformer/transformer.py
index c6e943e..dd180c4 100644
--- a/src/text_recognizer/networks/transformer/transformer.py
+++ b/src/text_recognizer/networks/transformer/transformer.py
@@ -6,11 +6,25 @@ import numpy as np
 import torch
 from torch import nn
 from torch import Tensor
+import torch.nn.functional as F
 
 from text_recognizer.networks.transformer.attention import MultiHeadAttention
 from text_recognizer.networks.util import activation_function
 
 
+class GEGLU(nn.Module):
+    """GLU activation for improving feedforward activations."""
+
+    def __init__(self, dim_in: int, dim_out: int) -> None:
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward propagation."""
+        x, gate = self.proj(x).chunk(2, dim=-1)
+        return x * F.gelu(gate)
+
+
 def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList:
     return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)])
 
@@ -36,9 +50,17 @@ class _ConvolutionalLayer(nn.Module):
         activation: str = "relu",
     ) -> None:
         super().__init__()
+
+        in_projection = (
+            nn.Sequential(
+                nn.Linear(hidden_dim, expansion_dim), activation_function(activation)
+            )
+            if activation != "glu"
+            else GEGLU(hidden_dim, expansion_dim)
+        )
+
         self.layer = nn.Sequential(
-            nn.Linear(in_features=hidden_dim, out_features=expansion_dim),
-            activation_function(activation),
+            in_projection,
             nn.Dropout(p=dropout_rate),
             nn.Linear(in_features=expansion_dim, out_features=hidden_dim),
         )
diff --git a/src/text_recognizer/networks/util.py b/src/text_recognizer/networks/util.py
index e2d7955..711a952 100644
--- a/src/text_recognizer/networks/util.py
+++ b/src/text_recognizer/networks/util.py
@@ -39,6 +39,7 @@ def activation_function(activation: str) -> Type[nn.Module]:
         [
             ["elu", nn.ELU(inplace=True)],
             ["gelu", nn.GELU()],
+            ["glu", nn.GLU()],
             ["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)],
             ["none", nn.Identity()],
             ["relu", nn.ReLU(inplace=True)],
diff --git a/src/text_recognizer/networks/vit.py b/src/text_recognizer/networks/vit.py
new file mode 100644
index 0000000..efb3701
--- /dev/null
+++ b/src/text_recognizer/networks/vit.py
@@ -0,0 +1,150 @@
+"""A Vision Transformer.
+
+Inspired by:
+https://openreview.net/pdf?id=YicbFdNTTy
+
+"""
+from typing import Optional, Tuple
+
+from einops import rearrange, repeat
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer import Transformer
+
+
+class ViT(nn.Module):
+    """Transfomer for image to sequence prediction."""
+
+    def __init__(
+        self,
+        num_encoder_layers: int,
+        num_decoder_layers: int,
+        hidden_dim: int,
+        vocab_size: int,
+        num_heads: int,
+        expansion_dim: int,
+        patch_dim: Tuple[int, int],
+        image_size: Tuple[int, int],
+        dropout_rate: float,
+        trg_pad_index: int,
+        max_len: int,
+        activation: str = "gelu",
+    ) -> None:
+        super().__init__()
+
+        self.trg_pad_index = trg_pad_index
+        self.patch_dim = patch_dim
+        self.num_patches = image_size[-1] // self.patch_dim[1]
+
+        # Encoder
+        self.patch_to_embedding = nn.Linear(
+            self.patch_dim[0] * self.patch_dim[1], hidden_dim
+        )
+        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
+        self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
+        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
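+        # One learned positional embedding is shared by the patch sequence and the target characters.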
+        self.dropout = nn.Dropout(dropout_rate)
+        self._init()
+
+        self.transformer = Transformer(
+            num_encoder_layers,
+            num_decoder_layers,
+            hidden_dim,
+            num_heads,
+            expansion_dim,
+            dropout_rate,
+            activation,
+        )
+
+        self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
+
+    def _init(self) -> None:
+        nn.init.normal_(self.character_embedding.weight, std=0.02)
+        # nn.init.normal_(self.pos_embedding.weight, std=0.02)
+
+    def _create_trg_mask(self, trg: Tensor) -> Tensor:
+        # Move this outside the transformer.
+        trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
+        trg_len = trg.shape[1]
+        trg_sub_mask = torch.tril(
+            torch.ones((trg_len, trg_len), device=trg.device)
+        ).bool()
+        trg_mask = trg_pad_mask & trg_sub_mask
+        return trg_mask
+
+    def encoder(self, src: Tensor) -> Tensor:
+        """Forward pass with the encoder of the transformer."""
+        return self.transformer.encoder(src)
+
+    def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
+        """Forward pass with the decoder of the transformer + classification head."""
+        return self.head(
+            self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+        )
+
+    def extract_image_features(self, src: Tensor) -> Tensor:
+        """Extracts image features with a backbone neural network.
+
+        It seem like the winning idea was to swap channels and width dimension and collapse
+        the height dimension. The transformer is learning like a baby with this implementation!!! :D
+        Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D
+
+        Args:
+            src (Tensor): Input tensor.
+
+        Returns:
+            Tensor: The input sequence to the transformer encoder.
+
+        """
+        # If batch dimension is missing, it needs to be added.
+        if len(src.shape) < 4:
+            src = src[(None,) * (4 - len(src.shape))]
+
+        patches = rearrange(
+            src,
+            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
+            p1=self.patch_dim[0],
+            p2=self.patch_dim[1],
+        )
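+        # Each p1 x p2 patch is flattened into a vector of length p1 * p2 * c, ordered row-major over the image.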
+
+        # From patches to encoded sequence.
+        x = self.patch_to_embedding(patches)
+        b, n, _ = x.shape
+        cls_tokens = repeat(self.cls_token, "() n d -> b n d", b=b)
+        x = torch.cat((cls_tokens, x), dim=1)
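+        # The n patch embeddings plus the prepended cls token each receive a learned positional embedding.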
+        x += self.pos_embedding[:, : (n + 1)]
+        x = self.dropout(x)
+
+        return x
+
+    def target_embedding(self, trg: Tensor) -> Tensor:
+        """Encodes the target tensor with character embedding and positional embedding.
+
+        Args:
+            trg (Tensor): Target tensor.
+
+        Returns:
+            Tensor: The encoded target tensor.
+
+        """
+        _, n = trg.shape
+        trg = self.character_embedding(trg.long())
+        trg += self.pos_embedding[:, :n]
+        return trg
+
+    def decode_image_features(self, h: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+        """Takes images features from the backbone and decodes them with the transformer."""
+        trg_mask = self._create_trg_mask(trg)
+        trg = self.target_embedding(trg)
+        out = self.transformer(h, trg, trg_mask=trg_mask)
+
+        logits = self.head(out)
+        return logits
+
+    def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+        """Forward pass with CNN transfomer."""
+        h = self.extract_image_features(x)
+        logits = self.decode_image_features(h, trg)
+        return logits
diff --git a/src/text_recognizer/networks/wide_resnet.py b/src/text_recognizer/networks/wide_resnet.py
index 28f3380..b767778 100644
--- a/src/text_recognizer/networks/wide_resnet.py
+++ b/src/text_recognizer/networks/wide_resnet.py
@@ -113,6 +113,7 @@ class WideResidualNetwork(nn.Module):
         dropout_rate: float = 0.0,
         num_layers: int = 3,
         block: Type[nn.Module] = WideBlock,
+        num_stages: Optional[List[int]] = None,
         activation: str = "relu",
         use_decoder: bool = True,
     ) -> None:
@@ -127,6 +128,7 @@ class WideResidualNetwork(nn.Module):
             dropout_rate (float): The dropout rate. Defaults to 0.0.
             num_layers (int): Number of layers of blocks. Defaults to 3.
             block (Type[nn.Module]): The default block is WideBlock. Defaults to WideBlock.
+            num_stages (Optional[List[int]]): If given, use these channel counts for the stages instead of
+                the computed defaults. Defaults to None.
             activation (str): Name of the activation to use. Defaults to "relu".
             use_decoder (bool): If True, the network output character predictions, if False, the network outputs a
                 latent vector. Defaults to True.
@@ -149,9 +151,14 @@ class WideResidualNetwork(nn.Module):
         self.dropout_rate = dropout_rate
         self.activation = activation_function(activation)
 
-        self.num_stages = [self.in_planes] + [
-            self.in_planes * 2 ** n * self.width_factor for n in range(self.num_layers)
-        ]
+        if num_stages is None:
+            self.num_stages = [self.in_planes] + [
+                self.in_planes * 2 ** n * self.width_factor
+                for n in range(self.num_layers)
+            ]
+        else:
+            self.num_stages = [self.in_planes] + num_stages
+
         self.num_stages = list(zip(self.num_stages, self.num_stages[1:]))
         self.strides = [1] + [2] * (self.num_layers - 1)
 
-- 
cgit v1.2.3-70-g09d2