From 7e8e54e84c63171e748bbf09516fd517e6821ace Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Sat, 20 Mar 2021 18:09:06 +0100
Subject: Inital commit for refactoring to lightning

---
 text_recognizer/networks/__init__.py               |  43 +++
 text_recognizer/networks/beam.py                   |  83 +++++
 text_recognizer/networks/cnn.py                    | 101 +++++
 text_recognizer/networks/cnn_transformer.py        | 158 ++++++++
 text_recognizer/networks/crnn.py                   | 110 ++++++
 text_recognizer/networks/ctc.py                    |  58 +++
 text_recognizer/networks/densenet.py               | 225 +++++++++++
 text_recognizer/networks/lenet.py                  |  68 ++++
 text_recognizer/networks/loss/__init__.py          |   2 +
 text_recognizer/networks/loss/loss.py              |  69 ++++
 text_recognizer/networks/metrics.py                | 123 +++++++
 text_recognizer/networks/mlp.py                    |  73 ++++
 text_recognizer/networks/residual_network.py       | 310 ++++++++++++++++
 text_recognizer/networks/stn.py                    |  44 +++
 text_recognizer/networks/transducer/__init__.py    |   3 +
 text_recognizer/networks/transducer/tds_conv.py    | 208 +++++++++++
 text_recognizer/networks/transducer/test.py        |  60 +++
 text_recognizer/networks/transducer/transducer.py  | 410 +++++++++++++++++++++
 text_recognizer/networks/transformer/__init__.py   |   3 +
 text_recognizer/networks/transformer/attention.py  |  93 +++++
 .../networks/transformer/positional_encoding.py    |  32 ++
 .../networks/transformer/transformer.py            | 264 +++++++++++++
 text_recognizer/networks/unet.py                   | 255 +++++++++++++
 text_recognizer/networks/util.py                   |  89 +++++
 text_recognizer/networks/vit.py                    | 150 ++++++++
 text_recognizer/networks/vq_transformer.py         | 150 ++++++++
 text_recognizer/networks/vqvae/__init__.py         |   5 +
 text_recognizer/networks/vqvae/decoder.py          | 133 +++++++
 text_recognizer/networks/vqvae/encoder.py          | 147 ++++++++
 text_recognizer/networks/vqvae/vector_quantizer.py | 119 ++++++
 text_recognizer/networks/vqvae/vqvae.py            |  74 ++++
 text_recognizer/networks/wide_resnet.py            | 221 +++++++++++
 32 files changed, 3883 insertions(+)
 create mode 100644 text_recognizer/networks/__init__.py
 create mode 100644 text_recognizer/networks/beam.py
 create mode 100644 text_recognizer/networks/cnn.py
 create mode 100644 text_recognizer/networks/cnn_transformer.py
 create mode 100644 text_recognizer/networks/crnn.py
 create mode 100644 text_recognizer/networks/ctc.py
 create mode 100644 text_recognizer/networks/densenet.py
 create mode 100644 text_recognizer/networks/lenet.py
 create mode 100644 text_recognizer/networks/loss/__init__.py
 create mode 100644 text_recognizer/networks/loss/loss.py
 create mode 100644 text_recognizer/networks/metrics.py
 create mode 100644 text_recognizer/networks/mlp.py
 create mode 100644 text_recognizer/networks/residual_network.py
 create mode 100644 text_recognizer/networks/stn.py
 create mode 100644 text_recognizer/networks/transducer/__init__.py
 create mode 100644 text_recognizer/networks/transducer/tds_conv.py
 create mode 100644 text_recognizer/networks/transducer/test.py
 create mode 100644 text_recognizer/networks/transducer/transducer.py
 create mode 100644 text_recognizer/networks/transformer/__init__.py
 create mode 100644 text_recognizer/networks/transformer/attention.py
 create mode 100644 text_recognizer/networks/transformer/positional_encoding.py
 create mode 100644 text_recognizer/networks/transformer/transformer.py
 create mode 100644 text_recognizer/networks/unet.py
 create mode 100644 text_recognizer/networks/util.py
 create mode 100644 text_recognizer/networks/vit.py
 create mode 100644 text_recognizer/networks/vq_transformer.py
 create mode 100644 text_recognizer/networks/vqvae/__init__.py
 create mode 100644 text_recognizer/networks/vqvae/decoder.py
 create mode 100644 text_recognizer/networks/vqvae/encoder.py
 create mode 100644 text_recognizer/networks/vqvae/vector_quantizer.py
 create mode 100644 text_recognizer/networks/vqvae/vqvae.py
 create mode 100644 text_recognizer/networks/wide_resnet.py

(limited to 'text_recognizer/networks')

diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py
new file mode 100644
index 0000000..1521355
--- /dev/null
+++ b/text_recognizer/networks/__init__.py
@@ -0,0 +1,43 @@
+"""Network modules."""
+from .cnn import CNN
+from .cnn_transformer import CNNTransformer
+from .crnn import ConvolutionalRecurrentNetwork
+from .ctc import greedy_decoder
+from .densenet import DenseNet
+from .lenet import LeNet
+from .metrics import accuracy, cer, wer
+from .mlp import MLP
+from .residual_network import ResidualNetwork, ResidualNetworkEncoder
+from .transducer import load_transducer_loss, TDS2d
+from .transformer import Transformer
+from .unet import UNet
+from .util import sliding_window
+from .vit import ViT
+from .vq_transformer import VQTransformer
+from .vqvae import VQVAE
+from .wide_resnet import WideResidualNetwork
+
+__all__ = [
+    "accuracy",
+    "cer",
+    "CNN",
+    "CNNTransformer",
+    "ConvolutionalRecurrentNetwork",
+    "DenseNet",
+    "FCN",
+    "greedy_decoder",
+    "MLP",
+    "LeNet",
+    "load_transducer_loss",
+    "ResidualNetwork",
+    "ResidualNetworkEncoder",
+    "sliding_window",
+    "UNet",
+    "TDS2d",
+    "Transformer",
+    "ViT",
+    "VQTransformer",
+    "VQVAE",
+    "wer",
+    "WideResidualNetwork",
+]
diff --git a/text_recognizer/networks/beam.py b/text_recognizer/networks/beam.py
new file mode 100644
index 0000000..dccccdb
--- /dev/null
+++ b/text_recognizer/networks/beam.py
@@ -0,0 +1,83 @@
+"""Implementation of beam search decoder for a sequence to sequence network.
+
+Stolen from: https://github.com/budzianowski/PyTorch-Beam-Search-Decoding/blob/master/decode_beam.py
+
+"""
+# from typing import List
+# from Queue import PriorityQueue
+
+# from loguru import logger
+# import torch
+# from torch import nn
+# from torch import Tensor
+# import torch.nn.functional as F
+
+
+# class Node:
+#     def __init__(
+#         self, parent: Node, target_index: int, log_prob: Tensor, length: int
+#     ) -> None:
+#         self.parent = parent
+#         self.target_index = target_index
+#         self.log_prob = log_prob
+#         self.length = length
+#         self.reward = 0.0
+
+#     def eval(self, alpha: float = 1.0) -> Tensor:
+#         return self.log_prob / (self.length - 1 + 1e-6) + alpha * self.reward
+
+
+# @torch.no_grad()
+# def beam_decoder(
+#     network, mapper, device, memory: Tensor = None, max_len: int = 97,
+# ) -> Tensor:
+#     beam_width = 10
+#     topk = 1  # How many sentences to generate.
+
+#     trg_indices = [mapper(mapper.init_token)]
+
+#     end_nodes = []
+
+#     node = Node(None, trg_indices, 0, 1)
+#     nodes = PriorityQueue()
+
+#     nodes.put((node.eval(), node))
+#     q_size = 1
+
+#     # Beam search
+#     for _ in range(max_len):
+#         if q_size > 2000:
+#             logger.warning("Could not decoder input")
+#             break
+
+#         # Fetch the best node.
+#         score, n = nodes.get()
+#         decoder_input = n.target_index
+
+#         if n.target_index == mapper(mapper.eos_token) and n.parent is not None:
+#             end_nodes.append((score, n))
+
+#             # If we reached the maximum number of sentences required.
+#             if len(end_nodes) >= 1:
+#                 break
+#             else:
+#                 continue
+
+#         # Forward pass with transformer.
+#         trg = torch.tensor(trg_indices, device=device)[None, :].long()
+#         trg = network.target_embedding(trg)
+#         logits = network.decoder(trg=trg, memory=memory, trg_mask=None)
+#         log_prob = F.log_softmax(logits, dim=2)
+
+#         log_prob, indices = torch.topk(log_prob, beam_width)
+
+#         for new_k in range(beam_width):
+#             # TODO: continue from here
+#             token_index = indices[0][new_k].view(1, -1)
+#             log_p = log_prob[0][new_k].item()
+
+#             node = Node()
+
+#             pass
+
+#     pass
diff --git a/text_recognizer/networks/cnn.py b/text_recognizer/networks/cnn.py
new file mode 100644
index 0000000..1807bb9
--- /dev/null
+++ b/text_recognizer/networks/cnn.py
@@ -0,0 +1,101 @@
+"""Implementation of a simple backbone cnn network."""
+from typing import Callable, Dict, Optional, Tuple
+
+from einops.layers.torch import Rearrange
+import torch
+from torch import nn
+
+from text_recognizer.networks.util import activation_function
+
+
+class CNN(nn.Module):
+    """LeNet network for character prediction."""
+
+    def __init__(
+        self,
+        channels: Tuple[int, ...] = (1, 32, 64, 128),
+        kernel_sizes: Tuple[int, ...] = (4, 4, 4),
+        strides: Tuple[int, ...] = (2, 2, 2),
+        max_pool_kernel: int = 2,
+        dropout_rate: float = 0.2,
+        activation: Optional[str] = "relu",
+    ) -> None:
+        """Initialization of the LeNet network.
+
+        Args:
+            channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64).
+            kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2).
+            strides (Tuple[int, ...]): Stride length of the convolutional filter. Defaults to (2, 2, 2).
+            max_pool_kernel (int): 2D max pooling kernel. Defaults to 2.
+            dropout_rate (float): The dropout rate. Defaults to 0.2.
+            activation (Optional[str]): The name of non-linear activation function. Defaults to relu.
+
+        Raises:
+            RuntimeError: if the number of hyperparameters does not match in length.
+
+        """
+        super().__init__()
+
+        if len(channels) - 1 != len(kernel_sizes) and len(kernel_sizes) != len(strides):
+            raise RuntimeError("The number of the hyperparameters does not match.")
+
+        self.cnn = self._build_network(
+            channels, kernel_sizes, strides, max_pool_kernel, dropout_rate, activation,
+        )
+
+    def _build_network(
+        self,
+        channels: Tuple[int, ...],
+        kernel_sizes: Tuple[int, ...],
+        strides: Tuple[int, ...],
+        max_pool_kernel: int,
+        dropout_rate: float,
+        activation: str,
+    ) -> nn.Sequential:
+        # Load activation function.
+        activation_fn = activation_function(activation)
+
+        channels = list(channels)
+        in_channels = channels.pop(0)
+        configuration = zip(channels, kernel_sizes, strides)
+
+        modules = nn.ModuleList([])
+
+        for i, (out_channels, kernel_size, stride) in enumerate(configuration):
+            # Add max pool to reduce output size.
+            if i == len(channels) // 2:
+                modules.append(nn.MaxPool2d(max_pool_kernel))
+            if i == 0:
+                modules.append(
+                    nn.Conv2d(
+                        in_channels, out_channels, kernel_size, stride=stride, padding=1
+                    )
+                )
+            else:
+                modules.append(
+                    nn.Sequential(
+                        activation_fn,
+                        nn.BatchNorm2d(in_channels),
+                        nn.Conv2d(
+                            in_channels,
+                            out_channels,
+                            kernel_size,
+                            stride=stride,
+                            padding=1,
+                        ),
+                    )
+                )
+
+            if dropout_rate:
+                modules.append(nn.Dropout2d(p=dropout_rate))
+
+            in_channels = out_channels
+
+        return nn.Sequential(*modules)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """The feedforward pass."""
+        # If batch dimenstion is missing, it needs to be added.
+        if len(x.shape) < 4:
+            x = x[(None,) * (4 - len(x.shape))]
+        return self.cnn(x)
diff --git a/text_recognizer/networks/cnn_transformer.py b/text_recognizer/networks/cnn_transformer.py
new file mode 100644
index 0000000..9150b55
--- /dev/null
+++ b/text_recognizer/networks/cnn_transformer.py
@@ -0,0 +1,158 @@
+"""A CNN-Transformer for image to text recognition."""
+from typing import Dict, Optional, Tuple
+
+from einops import rearrange
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer import PositionalEncoding, Transformer
+from text_recognizer.networks.util import activation_function
+from text_recognizer.networks.util import configure_backbone
+
+
+class CNNTransformer(nn.Module):
+    """CNN+Transfomer for image to sequence prediction."""
+
+    def __init__(
+        self,
+        num_encoder_layers: int,
+        num_decoder_layers: int,
+        hidden_dim: int,
+        vocab_size: int,
+        num_heads: int,
+        adaptive_pool_dim: Tuple,
+        expansion_dim: int,
+        dropout_rate: float,
+        trg_pad_index: int,
+        max_len: int,
+        backbone: str,
+        backbone_args: Optional[Dict] = None,
+        activation: str = "gelu",
+        pool_kernel: Optional[Tuple[int, int]] = None,
+    ) -> None:
+        super().__init__()
+        self.trg_pad_index = trg_pad_index
+        self.vocab_size = vocab_size
+        self.backbone = configure_backbone(backbone, backbone_args)
+
+        if pool_kernel is not None:
+            self.max_pool = nn.MaxPool2d(pool_kernel, stride=2)
+        else:
+            self.max_pool = None
+
+        self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim)
+
+        self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
+        self.pos_dropout = nn.Dropout(p=dropout_rate)
+        self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+
+        nn.init.normal_(self.character_embedding.weight, std=0.02)
+
+        self.adaptive_pool = (
+            nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None
+        )
+
+        self.transformer = Transformer(
+            num_encoder_layers,
+            num_decoder_layers,
+            hidden_dim,
+            num_heads,
+            expansion_dim,
+            dropout_rate,
+            activation,
+        )
+
+        self.head = nn.Sequential(
+            # nn.Linear(hidden_dim, hidden_dim * 2),
+            # activation_function(activation),
+            nn.Linear(hidden_dim, vocab_size),
+        )
+
+    def _create_trg_mask(self, trg: Tensor) -> Tensor:
+        # Move this outside the transformer.
+        trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
+        trg_len = trg.shape[1]
+        trg_sub_mask = torch.tril(
+            torch.ones((trg_len, trg_len), device=trg.device)
+        ).bool()
+        trg_mask = trg_pad_mask & trg_sub_mask
+        return trg_mask
+
+    def encoder(self, src: Tensor) -> Tensor:
+        """Forward pass with the encoder of the transformer."""
+        return self.transformer.encoder(src)
+
+    def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
+        """Forward pass with the decoder of the transformer + classification head."""
+        return self.head(
+            self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+        )
+
+    def extract_image_features(self, src: Tensor) -> Tensor:
+        """Extracts image features with a backbone neural network.
+
+        It seem like the winning idea was to swap channels and width dimension and collapse
+        the height dimension. The transformer is learning like a baby with this implementation!!! :D
+        Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D
+
+        Args:
+            src (Tensor): Input tensor.
+
+        Returns:
+            Tensor: A input src to the transformer.
+
+        """
+        # If batch dimension is missing, it needs to be added.
+        if len(src.shape) < 4:
+            src = src[(None,) * (4 - len(src.shape))]
+
+        src = self.backbone(src)
+
+        if self.max_pool is not None:
+            src = self.max_pool(src)
+
+        if self.adaptive_pool is not None and len(src.shape) == 4:
+            src = rearrange(src, "b c h w -> b w c h")
+            src = self.adaptive_pool(src)
+            src = src.squeeze(3)
+        elif len(src.shape) == 4:
+            src = rearrange(src, "b c h w -> b (h w) c")
+
+        b, t, _ = src.shape
+
+        src += self.src_position_embedding[:, :t]
+        src = self.pos_dropout(src)
+
+        return src
+
+    def target_embedding(self, trg: Tensor) -> Tuple[Tensor, Tensor]:
+        """Encodes target tensor with embedding and postion.
+
+        Args:
+            trg (Tensor): Target tensor.
+
+        Returns:
+            Tuple[Tensor, Tensor]: Encoded target tensor and target mask.
+
+        """
+        trg = self.character_embedding(trg.long())
+        trg = self.trg_position_encoding(trg)
+        return trg
+
+    def decode_image_features(
+        self, image_features: Tensor, trg: Optional[Tensor] = None
+    ) -> Tensor:
+        """Takes images features from the backbone and decodes them with the transformer."""
+        trg_mask = self._create_trg_mask(trg)
+        trg = self.target_embedding(trg)
+        out = self.transformer(image_features, trg, trg_mask=trg_mask)
+
+        logits = self.head(out)
+        return logits
+
+    def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+        """Forward pass with CNN transfomer."""
+        image_features = self.extract_image_features(x)
+        logits = self.decode_image_features(image_features, trg)
+        return logits
diff --git a/text_recognizer/networks/crnn.py b/text_recognizer/networks/crnn.py
new file mode 100644
index 0000000..778e232
--- /dev/null
+++ b/text_recognizer/networks/crnn.py
@@ -0,0 +1,110 @@
+"""CRNN for handwritten text recognition."""
+from typing import Dict, Tuple
+
+from einops import rearrange, reduce
+from einops.layers.torch import Rearrange
+from loguru import logger
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import configure_backbone
+
+
+class ConvolutionalRecurrentNetwork(nn.Module):
+    """Network that takes a image of a text line and predicts tokens that are in the image."""
+
+    def __init__(
+        self,
+        backbone: str,
+        backbone_args: Dict = None,
+        input_size: int = 128,
+        hidden_size: int = 128,
+        bidirectional: bool = False,
+        num_layers: int = 1,
+        num_classes: int = 80,
+        patch_size: Tuple[int, int] = (28, 28),
+        stride: Tuple[int, int] = (1, 14),
+        recurrent_cell: str = "lstm",
+        avg_pool: bool = False,
+        use_sliding_window: bool = True,
+    ) -> None:
+        super().__init__()
+        self.backbone_args = backbone_args or {}
+        self.patch_size = patch_size
+        self.stride = stride
+        self.sliding_window = (
+            self._configure_sliding_window() if use_sliding_window else None
+        )
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.backbone = configure_backbone(backbone, backbone_args)
+        self.bidirectional = bidirectional
+        self.avg_pool = avg_pool
+
+        if recurrent_cell.upper() in ["LSTM", "GRU"]:
+            recurrent_cell = getattr(nn, recurrent_cell)
+        else:
+            logger.warning(
+                f"Option {recurrent_cell} not valid, defaulting to LSTM cell."
+            )
+            recurrent_cell = nn.LSTM
+
+        self.rnn = recurrent_cell(
+            input_size=self.input_size,
+            hidden_size=self.hidden_size,
+            bidirectional=bidirectional,
+            num_layers=num_layers,
+        )
+
+        decoder_size = self.hidden_size * 2 if self.bidirectional else self.hidden_size
+
+        self.decoder = nn.Sequential(
+            nn.Linear(in_features=decoder_size, out_features=num_classes),
+            nn.LogSoftmax(dim=2),
+        )
+
+    def _configure_sliding_window(self) -> nn.Sequential:
+        return nn.Sequential(
+            nn.Unfold(kernel_size=self.patch_size, stride=self.stride),
+            Rearrange(
+                "b (c h w) t -> b t c h w",
+                h=self.patch_size[0],
+                w=self.patch_size[1],
+                c=1,
+            ),
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM."""
+        if len(x.shape) < 4:
+            x = x[(None,) * (4 - len(x.shape))]
+
+        if self.sliding_window is not None:
+            # Create image patches with a sliding window kernel.
+            x = self.sliding_window(x)
+
+            # Rearrange from a sequence of patches for feedforward network.
+            b, t = x.shape[:2]
+            x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
+
+            x = self.backbone(x)
+
+            # Average pooling.
+            if self.avg_pool:
+                x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)
+            else:
+                x = rearrange(x, "(b t) h -> t b h", b=b, t=t)
+        else:
+            # Encode the entire image with a CNN, and use the channels as temporal dimension.
+            x = self.backbone(x)
+            x = rearrange(x, "b c h w -> b w c h")
+            if self.adaptive_pool is not None:
+                x = self.adaptive_pool(x)
+            x = x.squeeze(3)
+
+        # Sequence predictions.
+        x, _ = self.rnn(x)
+
+        # Sequence to classification layer.
+        x = self.decoder(x)
+        return x
diff --git a/text_recognizer/networks/ctc.py b/text_recognizer/networks/ctc.py
new file mode 100644
index 0000000..af9b700
--- /dev/null
+++ b/text_recognizer/networks/ctc.py
@@ -0,0 +1,58 @@
+"""Decodes the CTC output."""
+from typing import Callable, List, Optional, Tuple
+
+from einops import rearrange
+import torch
+from torch import Tensor
+
+from text_recognizer.datasets.util import EmnistMapper
+
+
+def greedy_decoder(
+    predictions: Tensor,
+    targets: Optional[Tensor] = None,
+    target_lengths: Optional[Tensor] = None,
+    character_mapper: Optional[Callable] = None,
+    blank_label: int = 79,
+    collapse_repeated: bool = True,
+) -> Tuple[List[str], List[str]]:
+    """Greedy CTC decoder.
+
+    Args:
+        predictions (Tensor): Tenor of network predictions, shape [time, batch, classes].
+        targets (Optional[Tensor]): Target tensor, shape is [batch, targets]. Defaults to None.
+        target_lengths (Optional[Tensor]): Length of each target tensor. Defaults to None.
+        character_mapper (Optional[Callable]): A emnist/character mapper for mapping integers to characters.  Defaults
+            to None.
+        blank_label (int): The blank character to be ignored. Defaults to 80.
+        collapse_repeated (bool): Collapase consecutive predictions of the same character. Defaults to True.
+
+    Returns:
+        Tuple[List[str], List[str]]: Tuple of decoded predictions and decoded targets.
+
+    """
+
+    if character_mapper is None:
+        character_mapper = EmnistMapper(pad_token="_")  # noqa: S106
+
+    predictions = rearrange(torch.argmax(predictions, dim=2), "t b -> b t")
+    decoded_predictions = []
+    decoded_targets = []
+    for i, prediction in enumerate(predictions):
+        decoded_prediction = []
+        decoded_target = []
+        if targets is not None and target_lengths is not None:
+            for target_index in targets[i][: target_lengths[i]]:
+                if target_index == blank_label:
+                    continue
+                decoded_target.append(character_mapper(int(target_index)))
+            decoded_targets.append(decoded_target)
+        for j, index in enumerate(prediction):
+            if index != blank_label:
+                if collapse_repeated and j != 0 and index == prediction[j - 1]:
+                    continue
+                decoded_prediction.append(index.item())
+        decoded_predictions.append(
+            [character_mapper(int(pred_index)) for pred_index in decoded_prediction]
+        )
+    return decoded_predictions, decoded_targets
diff --git a/text_recognizer/networks/densenet.py b/text_recognizer/networks/densenet.py
new file mode 100644
index 0000000..7dc58d9
--- /dev/null
+++ b/text_recognizer/networks/densenet.py
@@ -0,0 +1,225 @@
+"""Defines a Densely Connected Convolutional Networks in PyTorch.
+
+Sources:
+https://arxiv.org/abs/1608.06993
+https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py
+
+"""
+from typing import List, Optional, Union
+
+from einops.layers.torch import Rearrange
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import activation_function
+
+
+class _DenseLayer(nn.Module):
+    """A dense layer with pre-batch norm -> activation function -> Conv-layer x 2."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        growth_rate: int,
+        bn_size: int,
+        dropout_rate: float,
+        activation: str = "relu",
+    ) -> None:
+        super().__init__()
+        activation_fn = activation_function(activation)
+        self.dense_layer = [
+            nn.BatchNorm2d(in_channels),
+            activation_fn,
+            nn.Conv2d(
+                in_channels=in_channels,
+                out_channels=bn_size * growth_rate,
+                kernel_size=1,
+                stride=1,
+                bias=False,
+            ),
+            nn.BatchNorm2d(bn_size * growth_rate),
+            activation_fn,
+            nn.Conv2d(
+                in_channels=bn_size * growth_rate,
+                out_channels=growth_rate,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                bias=False,
+            ),
+        ]
+        if dropout_rate:
+            self.dense_layer.append(nn.Dropout(p=dropout_rate))
+
+        self.dense_layer = nn.Sequential(*self.dense_layer)
+
+    def forward(self, x: Union[Tensor, List[Tensor]]) -> Tensor:
+        if isinstance(x, list):
+            x = torch.cat(x, 1)
+        return self.dense_layer(x)
+
+
+class _DenseBlock(nn.Module):
+    def __init__(
+        self,
+        num_layers: int,
+        in_channels: int,
+        bn_size: int,
+        growth_rate: int,
+        dropout_rate: float,
+        activation: str = "relu",
+    ) -> None:
+        super().__init__()
+        self.dense_block = self._build_dense_blocks(
+            num_layers, in_channels, bn_size, growth_rate, dropout_rate, activation,
+        )
+
+    def _build_dense_blocks(
+        self,
+        num_layers: int,
+        in_channels: int,
+        bn_size: int,
+        growth_rate: int,
+        dropout_rate: float,
+        activation: str = "relu",
+    ) -> nn.ModuleList:
+        dense_block = []
+        for i in range(num_layers):
+            dense_block.append(
+                _DenseLayer(
+                    in_channels=in_channels + i * growth_rate,
+                    growth_rate=growth_rate,
+                    bn_size=bn_size,
+                    dropout_rate=dropout_rate,
+                    activation=activation,
+                )
+            )
+        return nn.ModuleList(dense_block)
+
+    def forward(self, x: Tensor) -> Tensor:
+        feature_maps = [x]
+        for layer in self.dense_block:
+            x = layer(feature_maps)
+            feature_maps.append(x)
+        return torch.cat(feature_maps, 1)
+
+
+class _Transition(nn.Module):
+    def __init__(
+        self, in_channels: int, out_channels: int, activation: str = "relu",
+    ) -> None:
+        super().__init__()
+        activation_fn = activation_function(activation)
+        self.transition = nn.Sequential(
+            nn.BatchNorm2d(in_channels),
+            activation_fn,
+            nn.Conv2d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=1,
+                stride=1,
+                bias=False,
+            ),
+            nn.AvgPool2d(kernel_size=2, stride=2),
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.transition(x)
+
+
+class DenseNet(nn.Module):
+    """Implementation of Densenet, a network archtecture that concats previous layers for maximum infomation flow."""
+
+    def __init__(
+        self,
+        growth_rate: int = 32,
+        block_config: List[int] = (6, 12, 24, 16),
+        in_channels: int = 1,
+        base_channels: int = 64,
+        num_classes: int = 80,
+        bn_size: int = 4,
+        dropout_rate: float = 0,
+        classifier: bool = True,
+        activation: str = "relu",
+    ) -> None:
+        super().__init__()
+        self.densenet = self._configure_densenet(
+            in_channels,
+            base_channels,
+            num_classes,
+            growth_rate,
+            block_config,
+            bn_size,
+            dropout_rate,
+            classifier,
+            activation,
+        )
+
+    def _configure_densenet(
+        self,
+        in_channels: int,
+        base_channels: int,
+        num_classes: int,
+        growth_rate: int,
+        block_config: List[int],
+        bn_size: int,
+        dropout_rate: float,
+        classifier: bool,
+        activation: str,
+    ) -> nn.Sequential:
+        activation_fn = activation_function(activation)
+        densenet = [
+            nn.Conv2d(
+                in_channels=in_channels,
+                out_channels=base_channels,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                bias=False,
+            ),
+            nn.BatchNorm2d(base_channels),
+            activation_fn,
+        ]
+
+        num_features = base_channels
+
+        for i, num_layers in enumerate(block_config):
+            densenet.append(
+                _DenseBlock(
+                    num_layers=num_layers,
+                    in_channels=num_features,
+                    bn_size=bn_size,
+                    growth_rate=growth_rate,
+                    dropout_rate=dropout_rate,
+                    activation=activation,
+                )
+            )
+            num_features = num_features + num_layers * growth_rate
+            if i != len(block_config) - 1:
+                densenet.append(
+                    _Transition(
+                        in_channels=num_features,
+                        out_channels=num_features // 2,
+                        activation=activation,
+                    )
+                )
+                num_features = num_features // 2
+
+        densenet.append(activation_fn)
+
+        if classifier:
+            densenet.append(nn.AdaptiveAvgPool2d((1, 1)))
+            densenet.append(Rearrange("b c h w -> b (c h w)"))
+            densenet.append(
+                nn.Linear(in_features=num_features, out_features=num_classes)
+            )
+
+        return nn.Sequential(*densenet)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward pass of Densenet."""
+        # If batch dimenstion is missing, it will be added.
+        if len(x.shape) < 4:
+            x = x[(None,) * (4 - len(x.shape))]
+        return self.densenet(x)
diff --git a/text_recognizer/networks/lenet.py b/text_recognizer/networks/lenet.py
new file mode 100644
index 0000000..527e1a0
--- /dev/null
+++ b/text_recognizer/networks/lenet.py
@@ -0,0 +1,68 @@
+"""Implementation of the LeNet network."""
+from typing import Callable, Dict, Optional, Tuple
+
+from einops.layers.torch import Rearrange
+import torch
+from torch import nn
+
+from text_recognizer.networks.util import activation_function
+
+
+class LeNet(nn.Module):
+    """LeNet network for character prediction."""
+
+    def __init__(
+        self,
+        channels: Tuple[int, ...] = (1, 32, 64),
+        kernel_sizes: Tuple[int, ...] = (3, 3, 2),
+        hidden_size: Tuple[int, ...] = (9216, 128),
+        dropout_rate: float = 0.2,
+        num_classes: int = 10,
+        activation_fn: Optional[str] = "relu",
+    ) -> None:
+        """Initialization of the LeNet network.
+
+        Args:
+            channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64).
+            kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2).
+            hidden_size (Tuple[int, ...]): Size of the flattend output form the convolutional layers.
+                Defaults to (9216, 128).
+            dropout_rate (float): The dropout rate. Defaults to 0.2.
+            num_classes (int): Number of classes. Defaults to 10.
+            activation_fn (Optional[str]): The name of non-linear activation function. Defaults to relu.
+
+        """
+        super().__init__()
+
+        activation_fn = activation_function(activation_fn)
+
+        self.layers = [
+            nn.Conv2d(
+                in_channels=channels[0],
+                out_channels=channels[1],
+                kernel_size=kernel_sizes[0],
+            ),
+            activation_fn,
+            nn.Conv2d(
+                in_channels=channels[1],
+                out_channels=channels[2],
+                kernel_size=kernel_sizes[1],
+            ),
+            activation_fn,
+            nn.MaxPool2d(kernel_sizes[2]),
+            nn.Dropout(p=dropout_rate),
+            Rearrange("b c h w -> b (c h w)"),
+            nn.Linear(in_features=hidden_size[0], out_features=hidden_size[1]),
+            activation_fn,
+            nn.Dropout(p=dropout_rate),
+            nn.Linear(in_features=hidden_size[1], out_features=num_classes),
+        ]
+
+        self.layers = nn.Sequential(*self.layers)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """The feedforward pass."""
+        # If batch dimenstion is missing, it needs to be added.
+        if len(x.shape) < 4:
+            x = x[(None,) * (4 - len(x.shape))]
+        return self.layers(x)
diff --git a/text_recognizer/networks/loss/__init__.py b/text_recognizer/networks/loss/__init__.py
new file mode 100644
index 0000000..b489264
--- /dev/null
+++ b/text_recognizer/networks/loss/__init__.py
@@ -0,0 +1,2 @@
+"""Loss module."""
+from .loss import EmbeddingLoss, LabelSmoothingCrossEntropy
diff --git a/text_recognizer/networks/loss/loss.py b/text_recognizer/networks/loss/loss.py
new file mode 100644
index 0000000..cf9fa0d
--- /dev/null
+++ b/text_recognizer/networks/loss/loss.py
@@ -0,0 +1,69 @@
+"""Implementations of custom loss functions."""
+from pytorch_metric_learning import distances, losses, miners, reducers
+import torch
+from torch import nn
+from torch import Tensor
+from torch.autograd import Variable
+import torch.nn.functional as F
+
+__all__ = ["EmbeddingLoss", "LabelSmoothingCrossEntropy"]
+
+
+class EmbeddingLoss:
+    """Metric loss for training encoders to produce information-rich latent embeddings."""
+
+    def __init__(self, margin: float = 0.2, type_of_triplets: str = "semihard") -> None:
+        self.distance = distances.CosineSimilarity()
+        self.reducer = reducers.ThresholdReducer(low=0)
+        self.loss_fn = losses.TripletMarginLoss(
+            margin=margin, distance=self.distance, reducer=self.reducer
+        )
+        self.miner = miners.MultiSimilarityMiner(epsilon=margin, distance=self.distance)
+
+    def __call__(self, embeddings: Tensor, labels: Tensor) -> Tensor:
+        """Computes the metric loss for the embeddings based on their labels.
+
+        Args:
+            embeddings (Tensor): The laten vectors encoded by the network.
+            labels (Tensor): Labels of the embeddings.
+
+        Returns:
+            Tensor: The metric loss for the embeddings.
+
+        """
+        hard_pairs = self.miner(embeddings, labels)
+        loss = self.loss_fn(embeddings, labels, hard_pairs)
+        return loss
+
+
+class LabelSmoothingCrossEntropy(nn.Module):
+    """Label smoothing loss function."""
+
+    def __init__(
+        self,
+        classes: int,
+        smoothing: float = 0.0,
+        ignore_index: int = None,
+        dim: int = -1,
+    ) -> None:
+        super().__init__()
+        self.confidence = 1.0 - smoothing
+        self.smoothing = smoothing
+        self.ignore_index = ignore_index
+        self.cls = classes
+        self.dim = dim
+
+    def forward(self, pred: Tensor, target: Tensor) -> Tensor:
+        """Calculates the loss."""
+        pred = pred.log_softmax(dim=self.dim)
+        with torch.no_grad():
+            # true_dist = pred.data.clone()
+            true_dist = torch.zeros_like(pred)
+            true_dist.fill_(self.smoothing / (self.cls - 1))
+            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
+            if self.ignore_index is not None:
+                true_dist[:, self.ignore_index] = 0
+                mask = torch.nonzero(target == self.ignore_index, as_tuple=False)
+                if mask.dim() > 0:
+                    true_dist.index_fill_(0, mask.squeeze(), 0.0)
+        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
diff --git a/text_recognizer/networks/metrics.py b/text_recognizer/networks/metrics.py
new file mode 100644
index 0000000..2605731
--- /dev/null
+++ b/text_recognizer/networks/metrics.py
@@ -0,0 +1,123 @@
+"""Utility functions for models."""
+from typing import Optional
+
+from einops import rearrange
+import Levenshtein as Lev
+import torch
+from torch import Tensor
+
+from text_recognizer.networks import greedy_decoder
+
+
+def accuracy(outputs: Tensor, labels: Tensor, pad_index: int = 53) -> float:
+    """Computes the accuracy.
+
+    Args:
+        outputs (Tensor): The output from the network.
+        labels (Tensor): Ground truth labels.
+        pad_index (int): Padding index.
+
+    Returns:
+        float: The accuracy for the batch.
+
+    """
+
+    _, predicted = torch.max(outputs, dim=-1)
+
+    # Mask out the pad tokens
+    mask = labels != pad_index
+
+    predicted *= mask
+    labels *= mask
+
+    acc = (predicted == labels).sum().float() / labels.shape[0]
+    acc = acc.item()
+    return acc
+
+
+def cer(
+    outputs: Tensor,
+    targets: Tensor,
+    batch_size: Optional[int] = None,
+    blank_label: Optional[int] = int,
+) -> float:
+    """Computes the character error rate.
+
+    Args:
+        outputs (Tensor): The output from the network.
+        targets (Tensor): Ground truth labels.
+        batch_size (Optional[int]): Batch size if target and output has been flattend.
+        blank_label (Optional[int]): The blank character to be ignored. Defaults to 79.
+
+    Returns:
+        float: The cer for the batch.
+
+    """
+    if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None:
+        targets = rearrange(targets, "(b t) -> b t", b=batch_size)
+        outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size)
+
+    target_lengths = torch.full(
+        size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
+    )
+    decoded_predictions, decoded_targets = greedy_decoder(
+        outputs, targets, target_lengths, blank_label=blank_label,
+    )
+
+    lev_dist = 0
+
+    for prediction, target in zip(decoded_predictions, decoded_targets):
+        prediction = "".join(prediction)
+        target = "".join(target)
+        prediction, target = (
+            prediction.replace(" ", ""),
+            target.replace(" ", ""),
+        )
+        lev_dist += Lev.distance(prediction, target)
+    return lev_dist / len(decoded_predictions)
+
+
+def wer(
+    outputs: Tensor,
+    targets: Tensor,
+    batch_size: Optional[int] = None,
+    blank_label: Optional[int] = int,
+) -> float:
+    """Computes the Word error rate.
+
+    Args:
+        outputs (Tensor): The output from the network.
+        targets (Tensor): Ground truth labels.
+        batch_size (optional[int]): Batch size if target and output has been flattend.
+        blank_label (Optional[int]): The blank character to be ignored. Defaults to 79.
+
+    Returns:
+        float: The wer for the batch.
+
+    """
+    if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None:
+        targets = rearrange(targets, "(b t) -> b t", b=batch_size)
+        outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size)
+
+    target_lengths = torch.full(
+        size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
+    )
+    decoded_predictions, decoded_targets = greedy_decoder(
+        outputs, targets, target_lengths, blank_label=blank_label,
+    )
+
+    lev_dist = 0
+
+    for prediction, target in zip(decoded_predictions, decoded_targets):
+        prediction = "".join(prediction)
+        target = "".join(target)
+
+        b = set(prediction.split() + target.split())
+        word2char = dict(zip(b, range(len(b))))
+
+        w1 = [chr(word2char[w]) for w in prediction.split()]
+        w2 = [chr(word2char[w]) for w in target.split()]
+
+        lev_dist += Lev.distance("".join(w1), "".join(w2))
+
+    return lev_dist / len(decoded_predictions)
diff --git a/text_recognizer/networks/mlp.py b/text_recognizer/networks/mlp.py
new file mode 100644
index 0000000..1101912
--- /dev/null
+++ b/text_recognizer/networks/mlp.py
@@ -0,0 +1,73 @@
+"""Defines the MLP network."""
+from typing import Callable, Dict, List, Optional, Union
+
+from einops.layers.torch import Rearrange
+import torch
+from torch import nn
+
+from text_recognizer.networks.util import activation_function
+
+
+class MLP(nn.Module):
+    """Multi layered perceptron network."""
+
+    def __init__(
+        self,
+        input_size: int = 784,
+        num_classes: int = 10,
+        hidden_size: Union[int, List] = 128,
+        num_layers: int = 3,
+        dropout_rate: float = 0.2,
+        activation_fn: str = "relu",
+    ) -> None:
+        """Initialization of the MLP network.
+
+        Args:
+            input_size (int): The input shape of the network. Defaults to 784.
+            num_classes (int): Number of classes in the dataset. Defaults to 10.
+            hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128.
+            num_layers (int): The number of hidden layers. Defaults to 3.
+            dropout_rate (float): The dropout rate at each layer. Defaults to 0.2.
+            activation_fn (str): Name of the activation function in the hidden layers. Defaults to
+                relu.
+
+        """
+        super().__init__()
+
+        activation_fn = activation_function(activation_fn)
+
+        if isinstance(hidden_size, int):
+            hidden_size = [hidden_size] * num_layers
+
+        self.layers = [
+            Rearrange("b c h w -> b (c h w)"),
+            nn.Linear(in_features=input_size, out_features=hidden_size[0]),
+            activation_fn,
+        ]
+
+        for i in range(num_layers - 1):
+            self.layers += [
+                nn.Linear(in_features=hidden_size[i], out_features=hidden_size[i + 1]),
+                activation_fn,
+            ]
+
+            if dropout_rate:
+                self.layers.append(nn.Dropout(p=dropout_rate))
+
+        self.layers.append(
+            nn.Linear(in_features=hidden_size[-1], out_features=num_classes)
+        )
+
+        self.layers = nn.Sequential(*self.layers)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """The feedforward pass."""
+        # If batch dimenstion is missing, it needs to be added.
+        if len(x.shape) < 4:
+            x = x[(None,) * (4 - len(x.shape))]
+        return self.layers(x)
+
+    @property
+    def __name__(self) -> str:
+        """Returns the name of the network."""
+        return "mlp"
diff --git a/text_recognizer/networks/residual_network.py b/text_recognizer/networks/residual_network.py
new file mode 100644
index 0000000..c33f419
--- /dev/null
+++ b/text_recognizer/networks/residual_network.py
@@ -0,0 +1,310 @@
+"""Residual CNN."""
+from functools import partial
+from typing import Callable, Dict, List, Optional, Type, Union
+
+from einops.layers.torch import Rearrange, Reduce
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import activation_function
+
+
+class Conv2dAuto(nn.Conv2d):
+    """Convolution with auto padding based on kernel size."""
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2)
+
+
+def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential:
+    """3x3 convolution with batch norm."""
+    conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,)
+    return nn.Sequential(
+        conv3x3(in_channels, out_channels, *args, **kwargs),
+        nn.BatchNorm2d(out_channels),
+    )
+
+
+class IdentityBlock(nn.Module):
+    """Residual with identity block."""
+
+    def __init__(
+        self, in_channels: int, out_channels: int, activation: str = "relu"
+    ) -> None:
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.blocks = nn.Identity()
+        self.activation_fn = activation_function(activation)
+        self.shortcut = nn.Identity()
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward pass."""
+        residual = x
+        if self.apply_shortcut:
+            residual = self.shortcut(x)
+        x = self.blocks(x)
+        x += residual
+        x = self.activation_fn(x)
+        return x
+
+    @property
+    def apply_shortcut(self) -> bool:
+        """Check if shortcut should be applied."""
+        return self.in_channels != self.out_channels
+
+
+class ResidualBlock(IdentityBlock):
+    """Residual with nonlinear shortcut."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        expansion: int = 1,
+        downsampling: int = 1,
+        *args,
+        **kwargs
+    ) -> None:
+        """Short summary.
+
+        Args:
+            in_channels (int): Number of in channels.
+            out_channels (int): umber of out channels.
+            expansion (int): Expansion factor of the out channels. Defaults to 1.
+            downsampling (int): Downsampling factor used in stride. Defaults to 1.
+            *args (type): Extra arguments.
+            **kwargs (type): Extra key value arguments.
+
+        """
+        super().__init__(in_channels, out_channels, *args, **kwargs)
+        self.expansion = expansion
+        self.downsampling = downsampling
+
+        self.shortcut = (
+            nn.Sequential(
+                nn.Conv2d(
+                    in_channels=self.in_channels,
+                    out_channels=self.expanded_channels,
+                    kernel_size=1,
+                    stride=self.downsampling,
+                    bias=False,
+                ),
+                nn.BatchNorm2d(self.expanded_channels),
+            )
+            if self.apply_shortcut
+            else None
+        )
+
+    @property
+    def expanded_channels(self) -> int:
+        """Computes the expanded output channels."""
+        return self.out_channels * self.expansion
+
+    @property
+    def apply_shortcut(self) -> bool:
+        """Check if shortcut should be applied."""
+        return self.in_channels != self.expanded_channels
+
+
+class BasicBlock(ResidualBlock):
+    """Basic ResNet block."""
+
+    expansion = 1
+
+    def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None:
+        super().__init__(in_channels, out_channels, *args, **kwargs)
+        self.blocks = nn.Sequential(
+            conv_bn(
+                in_channels=self.in_channels,
+                out_channels=self.out_channels,
+                bias=False,
+                stride=self.downsampling,
+            ),
+            self.activation_fn,
+            conv_bn(
+                in_channels=self.out_channels,
+                out_channels=self.expanded_channels,
+                bias=False,
+            ),
+        )
+
+
+class BottleNeckBlock(ResidualBlock):
+    """Bottleneck block to increase depth while minimizing parameter size."""
+
+    expansion = 4
+
+    def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None:
+        super().__init__(in_channels, out_channels, *args, **kwargs)
+        self.blocks = nn.Sequential(
+            conv_bn(
+                in_channels=self.in_channels,
+                out_channels=self.out_channels,
+                kernel_size=1,
+            ),
+            self.activation_fn,
+            conv_bn(
+                in_channels=self.out_channels,
+                out_channels=self.out_channels,
+                kernel_size=3,
+                stride=self.downsampling,
+            ),
+            self.activation_fn,
+            conv_bn(
+                in_channels=self.out_channels,
+                out_channels=self.expanded_channels,
+                kernel_size=1,
+            ),
+        )
+
+
+class ResidualLayer(nn.Module):
+    """ResNet layer."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        block: BasicBlock = BasicBlock,
+        num_blocks: int = 1,
+        *args,
+        **kwargs
+    ) -> None:
+        super().__init__()
+        downsampling = 2 if in_channels != out_channels else 1
+        self.blocks = nn.Sequential(
+            block(
+                in_channels, out_channels, *args, **kwargs, downsampling=downsampling
+            ),
+            *[
+                block(
+                    out_channels * block.expansion,
+                    out_channels,
+                    downsampling=1,
+                    *args,
+                    **kwargs
+                )
+                for _ in range(num_blocks - 1)
+            ]
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward pass."""
+        x = self.blocks(x)
+        return x
+
+
+class ResidualNetworkEncoder(nn.Module):
+    """Encoder network."""
+
+    def __init__(
+        self,
+        in_channels: int = 1,
+        block_sizes: Union[int, List[int]] = (32, 64),
+        depths: Union[int, List[int]] = (2, 2),
+        activation: str = "relu",
+        block: Type[nn.Module] = BasicBlock,
+        levels: int = 1,
+        *args,
+        **kwargs
+    ) -> None:
+        super().__init__()
+        self.block_sizes = (
+            block_sizes if isinstance(block_sizes, list) else [block_sizes] * levels
+        )
+        self.depths = depths if isinstance(depths, list) else [depths] * levels
+        self.activation = activation
+        self.gate = nn.Sequential(
+            nn.Conv2d(
+                in_channels=in_channels,
+                out_channels=self.block_sizes[0],
+                kernel_size=7,
+                stride=2,
+                padding=1,
+                bias=False,
+            ),
+            nn.BatchNorm2d(self.block_sizes[0]),
+            activation_function(self.activation),
+            # nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
+        )
+
+        self.blocks = self._configure_blocks(block)
+
+    def _configure_blocks(
+        self, block: Type[nn.Module], *args, **kwargs
+    ) -> nn.Sequential:
+        channels = [self.block_sizes[0]] + list(
+            zip(self.block_sizes, self.block_sizes[1:])
+        )
+        blocks = [
+            ResidualLayer(
+                in_channels=channels[0],
+                out_channels=channels[0],
+                num_blocks=self.depths[0],
+                block=block,
+                activation=self.activation,
+                *args,
+                **kwargs
+            )
+        ]
+        blocks += [
+            ResidualLayer(
+                in_channels=in_channels * block.expansion,
+                out_channels=out_channels,
+                num_blocks=num_blocks,
+                block=block,
+                activation=self.activation,
+                *args,
+                **kwargs
+            )
+            for (in_channels, out_channels), num_blocks in zip(
+                channels[1:], self.depths[1:]
+            )
+        ]
+
+        return nn.Sequential(*blocks)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward pass."""
+        # If batch dimenstion is missing, it needs to be added.
+        if len(x.shape) == 3:
+            x = x.unsqueeze(0)
+        x = self.gate(x)
+        x = self.blocks(x)
+        return x
+
+
+class ResidualNetworkDecoder(nn.Module):
+    """Classification head."""
+
+    def __init__(self, in_features: int, num_classes: int = 80) -> None:
+        super().__init__()
+        self.decoder = nn.Sequential(
+            Reduce("b c h w -> b c", "mean"),
+            nn.Linear(in_features=in_features, out_features=num_classes),
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward pass."""
+        return self.decoder(x)
+
+
+class ResidualNetwork(nn.Module):
+    """Full residual network."""
+
+    def __init__(self, in_channels: int, num_classes: int, *args, **kwargs) -> None:
+        super().__init__()
+        self.encoder = ResidualNetworkEncoder(in_channels, *args, **kwargs)
+        self.decoder = ResidualNetworkDecoder(
+            in_features=self.encoder.blocks[-1].blocks[-1].expanded_channels,
+            num_classes=num_classes,
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward pass."""
+        x = self.encoder(x)
+        x = self.decoder(x)
+        return x
diff --git a/text_recognizer/networks/stn.py b/text_recognizer/networks/stn.py
new file mode 100644
index 0000000..e9d216f
--- /dev/null
+++ b/text_recognizer/networks/stn.py
@@ -0,0 +1,44 @@
+"""Spatial Transformer Network."""
+
+from einops.layers.torch import Rearrange
+import torch
+from torch import nn
+from torch import Tensor
+import torch.nn.functional as F
+
+
+class SpatialTransformerNetwork(nn.Module):
+    """A network with differentiable attention.
+
+    Network that learns how to perform spatial transformations on the input image in order to enhance the
+    geometric invariance of the model.
+
+    # TODO: add arguments to make it more general.
+
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        # Initialize the identity transformation and its weights and biases.
+        linear = nn.Linear(32, 3 * 2)
+        linear.weight.data.zero_()
+        linear.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
+
+        self.theta = nn.Sequential(
+            nn.Conv2d(in_channels=1, out_channels=8, kernel_size=7),
+            nn.MaxPool2d(kernel_size=2, stride=2),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(in_channels=8, out_channels=10, kernel_size=5),
+            nn.MaxPool2d(kernel_size=2, stride=2),
+            nn.ReLU(inplace=True),
+            Rearrange("b c h w -> b (c h w)", h=3, w=3),
+            nn.Linear(in_features=10 * 3 * 3, out_features=32),
+            nn.ReLU(inplace=True),
+            linear,
+            Rearrange("b (row col) -> b row col", row=2, col=3),
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        """The spatial transformation."""
+        grid = F.affine_grid(self.theta(x), x.shape)
+        return F.grid_sample(x, grid, align_corners=False)
diff --git a/text_recognizer/networks/transducer/__init__.py b/text_recognizer/networks/transducer/__init__.py
new file mode 100644
index 0000000..8c19a01
--- /dev/null
+++ b/text_recognizer/networks/transducer/__init__.py
@@ -0,0 +1,3 @@
+"""Transducer modules."""
+from .tds_conv import TDS2d
+from .transducer import load_transducer_loss, Transducer
diff --git a/text_recognizer/networks/transducer/tds_conv.py b/text_recognizer/networks/transducer/tds_conv.py
new file mode 100644
index 0000000..5fb8ba9
--- /dev/null
+++ b/text_recognizer/networks/transducer/tds_conv.py
@@ -0,0 +1,208 @@
+"""Time-Depth Separable Convolutions.
+
+References:
+    https://arxiv.org/abs/1904.02619
+    https://arxiv.org/pdf/2010.01003.pdf
+
+Code stolen from:
+    https://github.com/facebookresearch/gtn_applications
+
+
+"""
+from typing import List, Tuple
+
+from einops import rearrange
+import gtn
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class TDSBlock2d(nn.Module):
+    """Internal block of a 2D TDSC network."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        img_depth: int,
+        kernel_size: Tuple[int],
+        dropout_rate: float,
+    ) -> None:
+        super().__init__()
+
+        self.in_channels = in_channels
+        self.img_depth = img_depth
+        self.kernel_size = kernel_size
+        self.dropout_rate = dropout_rate
+        self.fc_dim = in_channels * img_depth
+
+        # Network placeholders.
+        self.conv = None
+        self.mlp = None
+        self.instance_norm = None
+
+        self._build_block()
+
+    def _build_block(self) -> None:
+        # Convolutional block.
+        self.conv = nn.Sequential(
+            nn.Conv3d(
+                in_channels=self.in_channels,
+                out_channels=self.in_channels,
+                kernel_size=(1, self.kernel_size[0], self.kernel_size[1]),
+                padding=(0, self.kernel_size[0] // 2, self.kernel_size[1] // 2),
+            ),
+            nn.ReLU(inplace=True),
+            nn.Dropout(self.dropout_rate),
+        )
+
+        # MLP block.
+        self.mlp = nn.Sequential(
+            nn.Linear(self.fc_dim, self.fc_dim),
+            nn.ReLU(inplace=True),
+            nn.Dropout(self.dropout_rate),
+            nn.Linear(self.fc_dim, self.fc_dim),
+            nn.Dropout(self.dropout_rate),
+        )
+
+        # Instance norm.
+        self.instance_norm = nn.ModuleList(
+            [
+                nn.InstanceNorm2d(self.fc_dim, affine=True),
+                nn.InstanceNorm2d(self.fc_dim, affine=True),
+            ]
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward pass.
+
+        Args:
+            x (Tensor): Input tensor.
+
+        Shape:
+            - x: :math: `(B, CD, H, W)`
+
+        Returns:
+            Tensor: Output tensor.
+
+        """
+        B, CD, H, W = x.shape
+        C, D = self.in_channels, self.img_depth
+        residual = x
+        x = rearrange(x, "b (c d) h w -> b c d h w", c=C, d=D)
+        x = self.conv(x)
+        x = rearrange(x, "b c d h w -> b (c d) h w")
+        x += residual
+
+        x = self.instance_norm[0](x)
+
+        x = self.mlp(x.transpose(1, 3)).transpose(1, 3) + x
+        x + self.instance_norm[1](x)
+
+        # Output shape: [B, CD, H, W]
+        return x
+
+
+class TDS2d(nn.Module):
+    """TDS Netowrk.
+
+    Structure is the following:
+        Downsample layer -> TDS2d group -> ... -> Linear output layer
+
+
+    """
+
+    def __init__(
+        self,
+        input_dim: int,
+        output_dim: int,
+        depth: int,
+        tds_groups: Tuple[int],
+        kernel_size: Tuple[int],
+        dropout_rate: float,
+        in_channels: int = 1,
+    ) -> None:
+        super().__init__()
+
+        self.in_channels = in_channels
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+        self.depth = depth
+        self.tds_groups = tds_groups
+        self.kernel_size = kernel_size
+        self.dropout_rate = dropout_rate
+
+        self.tds = None
+        self.fc = None
+
+        self._build_network()
+
+    def _build_network(self) -> None:
+        in_channels = self.in_channels
+        modules = []
+        stride_h = np.prod([grp["stride"][0] for grp in self.tds_groups])
+        if self.input_dim % stride_h:
+            raise RuntimeError(
+                f"Image height not divisible by total stride {stride_h}."
+            )
+
+        for tds_group in self.tds_groups:
+            # Add downsample layer.
+            out_channels = self.depth * tds_group["channels"]
+            modules.extend(
+                [
+                    nn.Conv2d(
+                        in_channels=in_channels,
+                        out_channels=out_channels,
+                        kernel_size=self.kernel_size,
+                        padding=(self.kernel_size[0] // 2, self.kernel_size[1] // 2),
+                        stride=tds_group["stride"],
+                    ),
+                    nn.ReLU(inplace=True),
+                    nn.Dropout(self.dropout_rate),
+                    nn.InstanceNorm2d(out_channels, affine=True),
+                ]
+            )
+
+            for _ in range(tds_group["num_blocks"]):
+                modules.append(
+                    TDSBlock2d(
+                        tds_group["channels"],
+                        self.depth,
+                        self.kernel_size,
+                        self.dropout_rate,
+                    )
+                )
+
+            in_channels = out_channels
+
+        self.tds = nn.Sequential(*modules)
+        self.fc = nn.Linear(in_channels * self.input_dim // stride_h, self.output_dim)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward pass.
+
+        Args:
+            x (Tensor): Input tensor.
+
+        Shape:
+            - x: :math: `(B, H, W)`
+
+        Returns:
+            Tensor: Output tensor.
+
+        """
+        if len(x.shape) == 4:
+            x = x.squeeze(1)  # Squeeze the channel dim away.
+
+        B, H, W = x.shape
+        x = rearrange(
+            x, "b (h1 h2) w -> b h1 h2 w", h1=self.in_channels, h2=H // self.in_channels
+        )
+        x = self.tds(x)
+
+        # x shape: [B, C, H, W]
+        x = rearrange(x, "b c h w -> b w (c h)")
+
+        return self.fc(x)
diff --git a/text_recognizer/networks/transducer/test.py b/text_recognizer/networks/transducer/test.py
new file mode 100644
index 0000000..cadcecc
--- /dev/null
+++ b/text_recognizer/networks/transducer/test.py
@@ -0,0 +1,60 @@
+import torch
+from torch import nn
+
+from text_recognizer.networks.transducer import load_transducer_loss, Transducer
+import unittest
+
+
+class TestTransducer(unittest.TestCase):
+    def test_viterbi(self):
+        T = 5
+        N = 4
+        B = 2
+
+        # fmt: off
+        emissions1 = torch.tensor((
+            0, 4, 0, 1,
+            0, 2, 1, 1,
+            0, 0, 0, 2,
+            0, 0, 0, 2,
+            8, 0, 0, 2,
+            ),
+            dtype=torch.float,
+        ).view(T, N)
+        emissions2 = torch.tensor((
+            0, 2, 1, 7,
+            0, 2, 9, 1,
+            0, 0, 0, 2,
+            0, 0, 5, 2,
+            1, 0, 0, 2,
+            ),
+            dtype=torch.float,
+        ).view(T, N)
+        # fmt: on
+
+        # Test without blank:
+        labels = [[1, 3, 0], [3, 2, 3, 2, 3]]
+        transducer = Transducer(
+            tokens=["a", "b", "c", "d"],
+            graphemes_to_idx={"a": 0, "b": 1, "c": 2, "d": 3},
+            blank="none",
+        )
+        emissions = torch.stack([emissions1, emissions2], dim=0)
+        predictions = transducer.viterbi(emissions)
+        self.assertEqual([p.tolist() for p in predictions], labels)
+
+        # Test with blank without repeats:
+        labels = [[1, 0], [2, 2]]
+        transducer = Transducer(
+            tokens=["a", "b", "c"],
+            graphemes_to_idx={"a": 0, "b": 1, "c": 2},
+            blank="optional",
+            allow_repeats=False,
+        )
+        emissions = torch.stack([emissions1, emissions2], dim=0)
+        predictions = transducer.viterbi(emissions)
+        self.assertEqual([p.tolist() for p in predictions], labels)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/text_recognizer/networks/transducer/transducer.py b/text_recognizer/networks/transducer/transducer.py
new file mode 100644
index 0000000..d7e3d08
--- /dev/null
+++ b/text_recognizer/networks/transducer/transducer.py
@@ -0,0 +1,410 @@
+"""Transducer and the transducer loss function.py
+
+Stolen from:
+    https://github.com/facebookresearch/gtn_applications/blob/master/transducer.py
+
+"""
+from pathlib import Path
+import itertools
+from typing import Dict, List, Optional, Union, Tuple
+
+from loguru import logger
+import gtn
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.datasets.iam_preprocessor import Preprocessor
+
+
+def make_scalar_graph(weight) -> gtn.Graph:
+    scalar = gtn.Graph()
+    scalar.add_node(True)
+    scalar.add_node(False, True)
+    scalar.add_arc(0, 1, 0, 0, weight)
+    return scalar
+
+
+def make_chain_graph(sequence) -> gtn.Graph:
+    graph = gtn.Graph(False)
+    graph.add_node(True)
+    for i, s in enumerate(sequence):
+        graph.add_node(False, i == (len(sequence) - 1))
+        graph.add_arc(i, i + 1, s)
+    return graph
+
+
+def make_transitions_graph(
+    ngram: int, num_tokens: int, calc_grad: bool = False
+) -> gtn.Graph:
+    transitions = gtn.Graph(calc_grad)
+    transitions.add_node(True, ngram == 1)
+
+    state_map = {(): 0}
+
+    # First build transitions which include <s>:
+    for n in range(1, ngram):
+        for state in itertools.product(range(num_tokens), repeat=n):
+            in_idx = state_map[state[:-1]]
+            out_idx = transitions.add_node(False, ngram == 1)
+            state_map[state] = out_idx
+            transitions.add_arc(in_idx, out_idx, state[-1])
+
+    for state in itertools.product(range(num_tokens), repeat=ngram):
+        state_idx = state_map[state[:-1]]
+        new_state_idx = state_map[state[1:]]
+        # p(state[-1] | state[:-1])
+        transitions.add_arc(state_idx, new_state_idx, state[-1])
+
+    if ngram > 1:
+        # Build transitions which include </s>:
+        end_idx = transitions.add_node(False, True)
+        for in_idx in range(end_idx):
+            transitions.add_arc(in_idx, end_idx, gtn.epsilon)
+
+    return transitions
+
+
+def make_lexicon_graph(word_pieces: List, graphemes_to_idx: Dict) -> gtn.Graph:
+    """Constructs a graph which transduces letters to word pieces."""
+    graph = gtn.Graph(False)
+    graph.add_node(True, True)
+    for i, wp in enumerate(word_pieces):
+        prev = 0
+        for l in wp[:-1]:
+            n = graph.add_node()
+            graph.add_arc(prev, n, graphemes_to_idx[l], gtn.epsilon)
+            prev = n
+        graph.add_arc(prev, 0, graphemes_to_idx[wp[-1]], i)
+    graph.arc_sort()
+    return graph
+
+
+def make_token_graph(
+    token_list: List, blank: str = "none", allow_repeats: bool = True
+) -> gtn.Graph:
+    """Constructs a graph with all the individual token transition models."""
+    if not allow_repeats and blank != "optional":
+        raise ValueError("Must use blank='optional' if disallowing repeats.")
+
+    ntoks = len(token_list)
+    graph = gtn.Graph(False)
+
+    # Creating nodes
+    graph.add_node(True, True)
+    for i in range(ntoks):
+        # We can consume one or more consecutive word
+        # pieces for each emission:
+        # E.g. [ab, ab, ab] transduces to [ab]
+        graph.add_node(False, blank != "forced")
+
+    if blank != "none":
+        graph.add_node()
+
+    # Creating arcs
+    if blank != "none":
+        # Blank index is assumed to be last (ntoks)
+        graph.add_arc(0, ntoks + 1, ntoks, gtn.epsilon)
+        graph.add_arc(ntoks + 1, 0, gtn.epsilon)
+
+    for i in range(ntoks):
+        graph.add_arc((ntoks + 1) if blank == "forced" else 0, i + 1, i)
+        graph.add_arc(i + 1, i + 1, i, gtn.epsilon)
+
+        if allow_repeats:
+            if blank == "forced":
+                # Allow transitions from token to blank only
+                graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon)
+            else:
+                # Allow transition from token to blank and all other tokens
+                graph.add_arc(i + 1, 0, gtn.epsilon)
+
+        else:
+            # allow transitions to blank and all other tokens except the same token
+            graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon)
+            for j in range(ntoks):
+                if i != j:
+                    graph.add_arc(i + 1, j + 1, j, j)
+
+    return graph
+
+
+class TransducerLossFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        inputs,
+        targets,
+        tokens,
+        lexicon,
+        transition_params=None,
+        transitions=None,
+        reduction="none",
+    ) -> Tensor:
+        B, T, C = inputs.shape
+
+        losses = [None] * B
+        emissions_graphs = [None] * B
+
+        if transitions is not None:
+            if transition_params is None:
+                raise ValueError("Specified transitions, but not transition params.")
+
+            cpu_data = transition_params.cpu().contiguous()
+            transitions.set_weights(cpu_data.data_ptr())
+            transitions.calc_grad = transition_params.requires_grad
+            transitions.zero_grad()
+
+        def process(b: int) -> None:
+            # Create emission graph:
+            emissions = gtn.linear_graph(T, C, inputs.requires_grad)
+            cpu_data = inputs[b].cpu().contiguous()
+            emissions.set_weights(cpu_data.data_ptr())
+            target = make_chain_graph(targets[b])
+            target.arc_sort(True)
+
+            # Create token tot grapheme decomposition graph
+            tokens_target = gtn.remove(gtn.project_output(gtn.compose(target, lexicon)))
+            tokens_target.arc_sort()
+
+            # Create alignment graph:
+            aligments = gtn.project_input(
+                gtn.remove(gtn.compose(tokens, tokens_target))
+            )
+            aligments.arc_sort()
+
+            # Add transitions scores:
+            if transitions is not None:
+                aligments = gtn.intersect(transitions, aligments)
+                aligments.arc_sort()
+
+            loss = gtn.forward_score(gtn.intersect(emissions, aligments))
+
+            # Normalize if needed:
+            if transitions is not None:
+                norm = gtn.forward_score(gtn.intersect(emissions, transitions))
+                loss = gtn.subtract(loss, norm)
+
+            losses[b] = gtn.negate(loss)
+
+            # Save for backward:
+            if emissions.calc_grad:
+                emissions_graphs[b] = emissions
+
+        gtn.parallel_for(process, range(B))
+
+        ctx.graphs = (losses, emissions_graphs, transitions)
+        ctx.input_shape = inputs.shape
+
+        # Optionally reduce by target length
+        if reduction == "mean":
+            scales = [(1 / len(t) if len(t) > 0 else 1.0) for t in targets]
+        else:
+            scales = [1.0] * B
+
+        ctx.scales = scales
+
+        loss = torch.tensor([l.item() * s for l, s in zip(losses, scales)])
+        return torch.mean(loss.to(inputs.device))
+
+    @staticmethod
+    def backward(ctx, grad_output) -> Tuple:
+        losses, emissions_graphs, transitions = ctx.graphs
+        scales = ctx.scales
+
+        B, T, C = ctx.input_shape
+        calc_emissions = ctx.needs_input_grad[0]
+        input_grad = torch.empty((B, T, C)) if calc_emissions else None
+
+        def process(b: int) -> None:
+            scale = make_scalar_graph(scales[b])
+            gtn.backward(losses[b], scale)
+            emissions = emissions_graphs[b]
+            if calc_emissions:
+                grad = emissions.grad().weights_to_numpy()
+                input_grad[b] = torch.tensor(grad).view(1, T, C)
+
+        gtn.parallel_for(process, range(B))
+
+        if calc_emissions:
+            input_grad = input_grad.to(grad_output.device)
+            input_grad *= grad_output / B
+
+        if ctx.needs_input_grad[4]:
+            grad = transitions.grad().weights_to_numpy()
+            transition_grad = torch.tensor(grad).to(grad_output.device)
+            transition_grad *= grad_output / B
+        else:
+            transition_grad = None
+
+        return (
+            input_grad,
+            None,  # target
+            None,  # tokens
+            None,  # lexicon
+            transition_grad,  # transition params
+            None,  # transitions graph
+            None,
+        )
+
+
+TransducerLoss = TransducerLossFunction.apply
+
+
+class Transducer(nn.Module):
+    def __init__(
+        self,
+        tokens: List,
+        graphemes_to_idx: Dict,
+        ngram: int = 0,
+        transitions: str = None,
+        blank: str = "none",
+        allow_repeats: bool = True,
+        reduction: str = "none",
+    ) -> None:
+        """A generic transducer loss function.
+
+        Args:
+            tokens (List) : A list of iterable objects (e.g. strings, tuples, etc)
+                representing the output tokens of the model (e.g. letters,
+                word-pieces, words). For example ["a", "b", "ab", "ba", "aba"]
+                could be a list of sub-word tokens.
+            graphemes_to_idx (dict) : A dictionary mapping grapheme units (e.g.
+                "a", "b", ..) to their corresponding integer index.
+            ngram (int) : Order of the token-level transition model. If `ngram=0`
+                then no transition model is used.
+            blank (string) : Specifies the usage of blank token
+                'none' - do not use blank token
+                'optional' - allow an optional blank inbetween tokens
+                'forced' - force a blank inbetween tokens (also referred to as garbage token)
+            allow_repeats (boolean) : If false, then we don't allow paths with
+                consecutive tokens in the alignment graph. This keeps the graph
+                unambiguous in the sense that the same input cannot transduce to
+                different outputs.
+        """
+        super().__init__()
+        if blank not in ["optional", "forced", "none"]:
+            raise ValueError(
+                "Invalid value specified for blank. Must be in ['optional', 'forced', 'none']"
+            )
+        self.tokens = make_token_graph(tokens, blank=blank, allow_repeats=allow_repeats)
+        self.lexicon = make_lexicon_graph(tokens, graphemes_to_idx)
+        self.ngram = ngram
+        if ngram > 0 and transitions is not None:
+            raise ValueError("Only one of ngram and transitions may be specified")
+
+        if ngram > 0:
+            transitions = make_transitions_graph(
+                ngram, len(tokens) + int(blank != "none"), True
+            )
+
+        if transitions is not None:
+            self.transitions = transitions
+            self.transitions.arc_sort()
+            self.transitions_params = nn.Parameter(
+                torch.zeros(self.transitions.num_arcs())
+            )
+        else:
+            self.transitions = None
+            self.transitions_params = None
+        self.reduction = reduction
+
+    def forward(self, inputs: Tensor, targets: Tensor) -> TransducerLoss:
+        TransducerLoss(
+            inputs,
+            targets,
+            self.tokens,
+            self.lexicon,
+            self.transitions_params,
+            self.transitions,
+            self.reduction,
+        )
+
+    def viterbi(self, outputs: Tensor) -> List[Tensor]:
+        B, T, C = outputs.shape
+
+        if self.transitions is not None:
+            cpu_data = self.transition_params.cpu().contiguous()
+            self.transitions.set_weights(cpu_data.data_ptr())
+            self.transitions.calc_grad = False
+
+        self.tokens.arc_sort()
+
+        paths = [None] * B
+
+        def process(b: int) -> None:
+            emissions = gtn.linear_graph(T, C, False)
+            cpu_data = outputs[b].cpu().contiguous()
+            emissions.set_weights(cpu_data.data_ptr())
+
+            if self.transitions is not None:
+                full_graph = gtn.intersect(emissions, self.transitions)
+            else:
+                full_graph = emissions
+
+            # Find the best path and remove back-off arcs:
+            path = gtn.remove(gtn.viterbi_path(full_graph))
+
+            # Left compose the viterbi path with the "aligment to token"
+            # transducer to get the outputs:
+            path = gtn.compose(path, self.tokens)
+
+            # When there are ambiguous paths (allow_repeats is true), we take
+            # the shortest:
+            path = gtn.viterbi_path(path)
+            path = gtn.remove(gtn.project_output(path))
+            paths[b] = path.labels_to_list()
+
+        gtn.parallel_for(process, range(B))
+        predictions = [torch.IntTensor(path) for path in paths]
+        return predictions
+
+
+def load_transducer_loss(
+    num_features: int,
+    ngram: int,
+    tokens: str,
+    lexicon: str,
+    transitions: str,
+    blank: str,
+    allow_repeats: bool,
+    prepend_wordsep: bool = False,
+    use_words: bool = False,
+    data_dir: Optional[Union[str, Path]] = None,
+    reduction: str = "mean",
+) -> Tuple[Transducer, int]:
+    if data_dir is None:
+        data_dir = (
+            Path(__file__).resolve().parents[4] / "data" / "raw" / "iam" / "iamdb"
+        )
+        logger.debug(f"Using data dir: {data_dir}")
+        if not data_dir.exists():
+            raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+    else:
+        data_dir = Path(data_dir)
+    processed_path = (
+        Path(__file__).resolve().parents[4] / "data" / "processed" / "iam_lines"
+    )
+    tokens_path = processed_path / tokens
+    lexicon_path = processed_path / lexicon
+
+    if transitions is not None:
+        transitions = gtn.load(str(processed_path / transitions))
+
+    preprocessor = Preprocessor(
+        data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep,
+    )
+
+    num_tokens = preprocessor.num_tokens
+
+    criterion = Transducer(
+        preprocessor.tokens,
+        preprocessor.graphemes_to_index,
+        ngram=ngram,
+        transitions=transitions,
+        blank=blank,
+        allow_repeats=allow_repeats,
+        reduction=reduction,
+    )
+
+    return criterion, num_tokens + int(blank != "none")
diff --git a/text_recognizer/networks/transformer/__init__.py b/text_recognizer/networks/transformer/__init__.py
new file mode 100644
index 0000000..9febc88
--- /dev/null
+++ b/text_recognizer/networks/transformer/__init__.py
@@ -0,0 +1,3 @@
+"""Transformer modules."""
+from .positional_encoding import PositionalEncoding
+from .transformer import Decoder, Encoder, EncoderLayer, Transformer
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
new file mode 100644
index 0000000..cce1ecc
--- /dev/null
+++ b/text_recognizer/networks/transformer/attention.py
@@ -0,0 +1,93 @@
+"""Implementes the attention module for the transformer."""
+from typing import Optional, Tuple
+
+from einops import rearrange
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class MultiHeadAttention(nn.Module):
+    """Implementation of multihead attention."""
+
+    def __init__(
+        self, hidden_dim: int, num_heads: int = 8, dropout_rate: float = 0.0
+    ) -> None:
+        super().__init__()
+        self.hidden_dim = hidden_dim
+        self.num_heads = num_heads
+        self.fc_q = nn.Linear(
+            in_features=hidden_dim, out_features=hidden_dim, bias=False
+        )
+        self.fc_k = nn.Linear(
+            in_features=hidden_dim, out_features=hidden_dim, bias=False
+        )
+        self.fc_v = nn.Linear(
+            in_features=hidden_dim, out_features=hidden_dim, bias=False
+        )
+        self.fc_out = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
+
+        self._init_weights()
+
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+    def _init_weights(self) -> None:
+        nn.init.normal_(
+            self.fc_q.weight,
+            mean=0,
+            std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
+        )
+        nn.init.normal_(
+            self.fc_k.weight,
+            mean=0,
+            std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
+        )
+        nn.init.normal_(
+            self.fc_v.weight,
+            mean=0,
+            std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
+        )
+        nn.init.xavier_normal_(self.fc_out.weight)
+
+    def scaled_dot_product_attention(
+        self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
+    ) -> Tensor:
+        """Calculates the scaled dot product attention."""
+
+        # Compute the energy.
+        energy = torch.einsum("bhlk,bhtk->bhlt", [query, key]) / np.sqrt(
+            query.shape[-1]
+        )
+
+        # If we have a mask for padding some inputs.
+        if mask is not None:
+            energy = energy.masked_fill(mask == 0, -np.inf)
+
+        # Compute the attention from the energy.
+        attention = torch.softmax(energy, dim=3)
+
+        out = torch.einsum("bhlt,bhtv->bhlv", [attention, value])
+        out = rearrange(out, "b head l v -> b l (head v)")
+        return out, attention
+
+    def forward(
+        self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
+    ) -> Tuple[Tensor, Tensor]:
+        """Forward pass for computing the multihead attention."""
+        # Get the query, key, and value tensor.
+        query = rearrange(
+            self.fc_q(query), "b l (head k) -> b head l k", head=self.num_heads
+        )
+        key = rearrange(
+            self.fc_k(key), "b t (head k) -> b head t k", head=self.num_heads
+        )
+        value = rearrange(
+            self.fc_v(value), "b t (head v) -> b head t v", head=self.num_heads
+        )
+
+        out, attention = self.scaled_dot_product_attention(query, key, value, mask)
+
+        out = self.fc_out(out)
+        out = self.dropout(out)
+        return out, attention
diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py
new file mode 100644
index 0000000..1ba5537
--- /dev/null
+++ b/text_recognizer/networks/transformer/positional_encoding.py
@@ -0,0 +1,32 @@
+"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence."""
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class PositionalEncoding(nn.Module):
+    """Encodes a sense of distance or time for transformer networks."""
+
+    def __init__(
+        self, hidden_dim: int, dropout_rate: float, max_len: int = 1000
+    ) -> None:
+        super().__init__()
+        self.dropout = nn.Dropout(p=dropout_rate)
+        self.max_len = max_len
+
+        pe = torch.zeros(max_len, hidden_dim)
+        position = torch.arange(0, max_len).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim)
+        )
+
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0)
+        self.register_buffer("pe", pe)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Encodes the tensor with a postional embedding."""
+        x = x + self.pe[:, : x.shape[1]]
+        return self.dropout(x)
diff --git a/text_recognizer/networks/transformer/transformer.py b/text_recognizer/networks/transformer/transformer.py
new file mode 100644
index 0000000..dd180c4
--- /dev/null
+++ b/text_recognizer/networks/transformer/transformer.py
@@ -0,0 +1,264 @@
+"""Transfomer module."""
+import copy
+from typing import Dict, Optional, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+import torch.nn.functional as F
+
+from text_recognizer.networks.transformer.attention import MultiHeadAttention
+from text_recognizer.networks.util import activation_function
+
+
+class GEGLU(nn.Module):
+    """GLU activation for improving feedforward activations."""
+
+    def __init__(self, dim_in: int, dim_out: int) -> None:
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward propagation."""
+        x, gate = self.proj(x).chunk(2, dim=-1)
+        return x * F.gelu(gate)
+
+
+def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList:
+    return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)])
+
+
+class _IntraLayerConnection(nn.Module):
+    """Preforms the residual connection inside the transfomer blocks and applies layernorm."""
+
+    def __init__(self, dropout_rate: float, hidden_dim: int) -> None:
+        super().__init__()
+        self.norm = nn.LayerNorm(normalized_shape=hidden_dim)
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+    def forward(self, src: Tensor, residual: Tensor) -> Tensor:
+        return self.norm(self.dropout(src) + residual)
+
+
+class _ConvolutionalLayer(nn.Module):
+    def __init__(
+        self,
+        hidden_dim: int,
+        expansion_dim: int,
+        dropout_rate: float,
+        activation: str = "relu",
+    ) -> None:
+        super().__init__()
+
+        in_projection = (
+            nn.Sequential(
+                nn.Linear(hidden_dim, expansion_dim), activation_function(activation)
+            )
+            if activation != "glu"
+            else GEGLU(hidden_dim, expansion_dim)
+        )
+
+        self.layer = nn.Sequential(
+            in_projection,
+            nn.Dropout(p=dropout_rate),
+            nn.Linear(in_features=expansion_dim, out_features=hidden_dim),
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.layer(x)
+
+
+class EncoderLayer(nn.Module):
+    """Transfomer encoding layer."""
+
+    def __init__(
+        self,
+        hidden_dim: int,
+        num_heads: int,
+        expansion_dim: int,
+        dropout_rate: float,
+        activation: str = "relu",
+    ) -> None:
+        super().__init__()
+        self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
+        self.cnn = _ConvolutionalLayer(
+            hidden_dim, expansion_dim, dropout_rate, activation
+        )
+        self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
+        self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
+
+    def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor:
+        """Forward pass through the encoder."""
+        # First block.
+        # Multi head attention.
+        out, _ = self.self_attention(src, src, src, mask)
+
+        # Add & norm.
+        out = self.block1(out, src)
+
+        # Second block.
+        # Apply 1D-convolution.
+        cnn_out = self.cnn(out)
+
+        # Add & norm.
+        out = self.block2(cnn_out, out)
+
+        return out
+
+
+class Encoder(nn.Module):
+    """Transfomer encoder module."""
+
+    def __init__(
+        self,
+        num_layers: int,
+        encoder_layer: Type[nn.Module],
+        norm: Optional[Type[nn.Module]] = None,
+    ) -> None:
+        super().__init__()
+        self.layers = _get_clones(encoder_layer, num_layers)
+        self.norm = norm
+
+    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
+        """Forward pass through all encoder layers."""
+        for layer in self.layers:
+            src = layer(src, src_mask)
+
+        if self.norm is not None:
+            src = self.norm(src)
+
+        return src
+
+
+class DecoderLayer(nn.Module):
+    """Transfomer decoder layer."""
+
+    def __init__(
+        self,
+        hidden_dim: int,
+        num_heads: int,
+        expansion_dim: int,
+        dropout_rate: float = 0.0,
+        activation: str = "relu",
+    ) -> None:
+        super().__init__()
+        self.hidden_dim = hidden_dim
+        self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
+        self.multihead_attention = MultiHeadAttention(
+            hidden_dim, num_heads, dropout_rate
+        )
+        self.cnn = _ConvolutionalLayer(
+            hidden_dim, expansion_dim, dropout_rate, activation
+        )
+        self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
+        self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
+        self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim)
+
+    def forward(
+        self,
+        trg: Tensor,
+        memory: Tensor,
+        trg_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        """Forward pass of the layer."""
+        out, _ = self.self_attention(trg, trg, trg, trg_mask)
+        trg = self.block1(out, trg)
+
+        out, _ = self.multihead_attention(trg, memory, memory, memory_mask)
+        trg = self.block2(out, trg)
+
+        out = self.cnn(trg)
+        out = self.block3(out, trg)
+
+        return out
+
+
+class Decoder(nn.Module):
+    """Transfomer decoder module."""
+
+    def __init__(
+        self,
+        decoder_layer: Type[nn.Module],
+        num_layers: int,
+        norm: Optional[Type[nn.Module]] = None,
+    ) -> None:
+        super().__init__()
+        self.layers = _get_clones(decoder_layer, num_layers)
+        self.num_layers = num_layers
+        self.norm = norm
+
+    def forward(
+        self,
+        trg: Tensor,
+        memory: Tensor,
+        trg_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        """Forward pass through the decoder."""
+        for layer in self.layers:
+            trg = layer(trg, memory, trg_mask, memory_mask)
+
+        if self.norm is not None:
+            trg = self.norm(trg)
+
+        return trg
+
+
+class Transformer(nn.Module):
+    """Transformer network."""
+
+    def __init__(
+        self,
+        num_encoder_layers: int,
+        num_decoder_layers: int,
+        hidden_dim: int,
+        num_heads: int,
+        expansion_dim: int,
+        dropout_rate: float,
+        activation: str = "relu",
+    ) -> None:
+        super().__init__()
+
+        # Configure encoder.
+        encoder_norm = nn.LayerNorm(hidden_dim)
+        encoder_layer = EncoderLayer(
+            hidden_dim, num_heads, expansion_dim, dropout_rate, activation
+        )
+        self.encoder = Encoder(num_encoder_layers, encoder_layer, encoder_norm)
+
+        # Configure decoder.
+        decoder_norm = nn.LayerNorm(hidden_dim)
+        decoder_layer = DecoderLayer(
+            hidden_dim, num_heads, expansion_dim, dropout_rate, activation
+        )
+        self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm)
+
+        self._reset_parameters()
+
+    def _reset_parameters(self) -> None:
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+    def forward(
+        self,
+        src: Tensor,
+        trg: Tensor,
+        src_mask: Optional[Tensor] = None,
+        trg_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        """Forward pass through the transformer."""
+        if src.shape[0] != trg.shape[0]:
+            print(trg.shape)
+            raise RuntimeError("The batch size of the src and trg must be the same.")
+        if src.shape[2] != trg.shape[2]:
+            raise RuntimeError(
+                "The number of features for the src and trg must be the same."
+            )
+
+        memory = self.encoder(src, src_mask)
+        output = self.decoder(trg, memory, trg_mask, memory_mask)
+        return output
diff --git a/text_recognizer/networks/unet.py b/text_recognizer/networks/unet.py
new file mode 100644
index 0000000..510910f
--- /dev/null
+++ b/text_recognizer/networks/unet.py
@@ -0,0 +1,255 @@
+"""UNet for segmentation."""
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import activation_function
+
+
+class _ConvBlock(nn.Module):
+    """Modified UNet convolutional block with dilation."""
+
+    def __init__(
+        self,
+        channels: List[int],
+        activation: str,
+        num_groups: int,
+        dropout_rate: float = 0.1,
+        kernel_size: int = 3,
+        dilation: int = 1,
+        padding: int = 0,
+    ) -> None:
+        super().__init__()
+        self.channels = channels
+        self.dropout_rate = dropout_rate
+        self.kernel_size = kernel_size
+        self.dilation = dilation
+        self.padding = padding
+        self.num_groups = num_groups
+        self.activation = activation_function(activation)
+        self.block = self._configure_block()
+        self.residual_conv = nn.Sequential(
+            nn.Conv2d(
+                self.channels[0], self.channels[-1], kernel_size=3, stride=1, padding=1
+            ),
+            self.activation,
+        )
+
+    def _configure_block(self) -> nn.Sequential:
+        block = []
+        for i in range(len(self.channels) - 1):
+            block += [
+                nn.Dropout(p=self.dropout_rate),
+                nn.GroupNorm(self.num_groups, self.channels[i]),
+                self.activation,
+                nn.Conv2d(
+                    self.channels[i],
+                    self.channels[i + 1],
+                    kernel_size=self.kernel_size,
+                    padding=self.padding,
+                    stride=1,
+                    dilation=self.dilation,
+                ),
+            ]
+
+        return nn.Sequential(*block)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Apply the convolutional block."""
+        residual = self.residual_conv(x)
+        return self.block(x) + residual
+
+
+class _DownSamplingBlock(nn.Module):
+    """Basic down sampling block."""
+
+    def __init__(
+        self,
+        channels: List[int],
+        activation: str,
+        num_groups: int,
+        pooling_kernel: Union[int, bool] = 2,
+        dropout_rate: float = 0.1,
+        kernel_size: int = 3,
+        dilation: int = 1,
+        padding: int = 0,
+    ) -> None:
+        super().__init__()
+        self.conv_block = _ConvBlock(
+            channels,
+            activation,
+            num_groups,
+            dropout_rate,
+            kernel_size,
+            dilation,
+            padding,
+        )
+        self.down_sampling = nn.MaxPool2d(pooling_kernel) if pooling_kernel else None
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+        """Return the convolutional block output and a down sampled tensor."""
+        x = self.conv_block(x)
+        x_down = self.down_sampling(x) if self.down_sampling is not None else x
+
+        return x_down, x
+
+
+class _UpSamplingBlock(nn.Module):
+    """The upsampling block of the UNet."""
+
+    def __init__(
+        self,
+        channels: List[int],
+        activation: str,
+        num_groups: int,
+        scale_factor: int = 2,
+        dropout_rate: float = 0.1,
+        kernel_size: int = 3,
+        dilation: int = 1,
+        padding: int = 0,
+    ) -> None:
+        super().__init__()
+        self.conv_block = _ConvBlock(
+            channels,
+            activation,
+            num_groups,
+            dropout_rate,
+            kernel_size,
+            dilation,
+            padding,
+        )
+        self.up_sampling = nn.Upsample(
+            scale_factor=scale_factor, mode="bilinear", align_corners=True
+        )
+
+    def forward(self, x: Tensor, x_skip: Optional[Tensor] = None) -> Tensor:
+        """Apply the up sampling and convolutional block."""
+        x = self.up_sampling(x)
+        if x_skip is not None:
+            x = torch.cat((x, x_skip), dim=1)
+        return self.conv_block(x)
+
+
+class UNet(nn.Module):
+    """UNet architecture."""
+
+    def __init__(
+        self,
+        in_channels: int = 1,
+        base_channels: int = 64,
+        num_classes: int = 3,
+        depth: int = 4,
+        activation: str = "relu",
+        num_groups: int = 8,
+        dropout_rate: float = 0.1,
+        pooling_kernel: int = 2,
+        scale_factor: int = 2,
+        kernel_size: Optional[List[int]] = None,
+        dilation: Optional[List[int]] = None,
+        padding: Optional[List[int]] = None,
+    ) -> None:
+        super().__init__()
+        self.depth = depth
+        self.num_groups = num_groups
+
+        if kernel_size is not None and dilation is not None and padding is not None:
+            if (
+                len(kernel_size) != depth
+                and len(dilation) != depth
+                and len(padding) != depth
+            ):
+                raise RuntimeError(
+                    "Length of convolutional parameters does not match the depth."
+                )
+            self.kernel_size = kernel_size
+            self.padding = padding
+            self.dilation = dilation
+
+        else:
+            self.kernel_size = [3] * depth
+            self.padding = [1] * depth
+            self.dilation = [1] * depth
+
+        self.dropout_rate = dropout_rate
+        self.conv = nn.Conv2d(
+            in_channels, base_channels, kernel_size=3, stride=1, padding=1
+        )
+
+        channels = [base_channels] + [base_channels * 2 ** i for i in range(depth)]
+        self.encoder_blocks = self._configure_down_sampling_blocks(
+            channels, activation, pooling_kernel
+        )
+        self.decoder_blocks = self._configure_up_sampling_blocks(
+            channels, activation, scale_factor
+        )
+
+        self.head = nn.Conv2d(base_channels, num_classes, kernel_size=1)
+
+    def _configure_down_sampling_blocks(
+        self, channels: List[int], activation: str, pooling_kernel: int
+    ) -> nn.ModuleList:
+        blocks = nn.ModuleList([])
+        for i in range(len(channels) - 1):
+            pooling_kernel = pooling_kernel if i < self.depth - 1 else False
+            dropout_rate = self.dropout_rate if i < 0 else 0
+            blocks += [
+                _DownSamplingBlock(
+                    [channels[i], channels[i + 1], channels[i + 1]],
+                    activation,
+                    self.num_groups,
+                    pooling_kernel,
+                    dropout_rate,
+                    self.kernel_size[i],
+                    self.dilation[i],
+                    self.padding[i],
+                )
+            ]
+
+        return blocks
+
+    def _configure_up_sampling_blocks(
+        self, channels: List[int], activation: str, scale_factor: int,
+    ) -> nn.ModuleList:
+        channels.reverse()
+        self.kernel_size.reverse()
+        self.dilation.reverse()
+        self.padding.reverse()
+        return nn.ModuleList(
+            [
+                _UpSamplingBlock(
+                    [channels[i] + channels[i + 1], channels[i + 1], channels[i + 1]],
+                    activation,
+                    self.num_groups,
+                    scale_factor,
+                    self.dropout_rate,
+                    self.kernel_size[i],
+                    self.dilation[i],
+                    self.padding[i],
+                )
+                for i in range(len(channels) - 2)
+            ]
+        )
+
+    def _encode(self, x: Tensor) -> List[Tensor]:
+        x_skips = []
+        for block in self.encoder_blocks:
+            x, x_skip = block(x)
+            x_skips.append(x_skip)
+        return x_skips
+
+    def _decode(self, x_skips: List[Tensor]) -> Tensor:
+        x = x_skips[-1]
+        for i, block in enumerate(self.decoder_blocks):
+            x = block(x, x_skips[-(i + 2)])
+        return x
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward pass with the UNet model."""
+        if len(x.shape) < 4:
+            x = x[(None,) * (4 - len(x.shape))]
+        x = self.conv(x)
+        x_skips = self._encode(x)
+        x = self._decode(x_skips)
+        return self.head(x)
diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py
new file mode 100644
index 0000000..131a6b4
--- /dev/null
+++ b/text_recognizer/networks/util.py
@@ -0,0 +1,89 @@
+"""Miscellaneous neural network functionality."""
+import importlib
+from pathlib import Path
+from typing import Dict, Tuple, Type
+
+from einops import rearrange
+from loguru import logger
+import torch
+from torch import nn
+
+
+def sliding_window(
+    images: torch.Tensor, patch_size: Tuple[int, int], stride: Tuple[int, int]
+) -> torch.Tensor:
+    """Creates patches of an image.
+
+    Args:
+        images (torch.Tensor): A Torch tensor of a 4D image(s), i.e. (batch, channel, height, width).
+        patch_size (Tuple[int, int]): The size of the patches to generate, e.g. 28x28 for EMNIST.
+        stride (Tuple[int, int]): The stride of the sliding window.
+
+    Returns:
+        torch.Tensor: A tensor with the shape (batch, patches, height, width).
+
+    """
+    unfold = nn.Unfold(kernel_size=patch_size, stride=stride)
+    # Preform the sliding window, unsqueeze as the channel dimesion is lost.
+    c = images.shape[1]
+    patches = unfold(images)
+    patches = rearrange(
+        patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1],
+    )
+    return patches
+
+
+def activation_function(activation: str) -> Type[nn.Module]:
+    """Returns the callable activation function."""
+    activation_fns = nn.ModuleDict(
+        [
+            ["elu", nn.ELU(inplace=True)],
+            ["gelu", nn.GELU()],
+            ["glu", nn.GLU()],
+            ["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)],
+            ["none", nn.Identity()],
+            ["relu", nn.ReLU(inplace=True)],
+            ["selu", nn.SELU(inplace=True)],
+        ]
+    )
+    return activation_fns[activation.lower()]
+
+
+def configure_backbone(backbone: str, backbone_args: Dict) -> Type[nn.Module]:
+    """Loads a backbone network."""
+    network_module = importlib.import_module("text_recognizer.networks")
+    backbone_ = getattr(network_module, backbone)
+
+    if "pretrained" in backbone_args:
+        logger.info("Loading pretrained backbone.")
+        checkpoint_file = Path(__file__).resolve().parents[2] / backbone_args.pop(
+            "pretrained"
+        )
+
+        # Loading state directory.
+        state_dict = torch.load(checkpoint_file)
+        network_args = state_dict["network_args"]
+        weights = state_dict["model_state"]
+
+        freeze = False
+        if "freeze" in backbone_args and backbone_args["freeze"] is True:
+            backbone_args.pop("freeze")
+            freeze = True
+        network_args = backbone_args
+
+        # Initializes the network with trained weights.
+        backbone = backbone_(**network_args)
+        backbone.load_state_dict(weights)
+        if freeze:
+            for params in backbone.parameters():
+                params.requires_grad = False
+    else:
+        backbone_ = getattr(network_module, backbone)
+        backbone = backbone_(**backbone_args)
+
+    if "remove_layers" in backbone_args and backbone_args["remove_layers"] is not None:
+        backbone = nn.Sequential(
+            *list(backbone.children())[:][: -backbone_args["remove_layers"]]
+        )
+
+    return backbone
diff --git a/text_recognizer/networks/vit.py b/text_recognizer/networks/vit.py
new file mode 100644
index 0000000..efb3701
--- /dev/null
+++ b/text_recognizer/networks/vit.py
@@ -0,0 +1,150 @@
+"""A Vision Transformer.
+
+Inspired by:
+https://openreview.net/pdf?id=YicbFdNTTy
+
+"""
+from typing import Optional, Tuple
+
+from einops import rearrange, repeat
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer import Transformer
+
+
+class ViT(nn.Module):
+    """Transfomer for image to sequence prediction."""
+
+    def __init__(
+        self,
+        num_encoder_layers: int,
+        num_decoder_layers: int,
+        hidden_dim: int,
+        vocab_size: int,
+        num_heads: int,
+        expansion_dim: int,
+        patch_dim: Tuple[int, int],
+        image_size: Tuple[int, int],
+        dropout_rate: float,
+        trg_pad_index: int,
+        max_len: int,
+        activation: str = "gelu",
+    ) -> None:
+        super().__init__()
+
+        self.trg_pad_index = trg_pad_index
+        self.patch_dim = patch_dim
+        self.num_patches = image_size[-1] // self.patch_dim[1]
+
+        # Encoder
+        self.patch_to_embedding = nn.Linear(
+            self.patch_dim[0] * self.patch_dim[1], hidden_dim
+        )
+        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
+        self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
+        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
+        self.dropout = nn.Dropout(dropout_rate)
+        self._init()
+
+        self.transformer = Transformer(
+            num_encoder_layers,
+            num_decoder_layers,
+            hidden_dim,
+            num_heads,
+            expansion_dim,
+            dropout_rate,
+            activation,
+        )
+
+        self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
+
+    def _init(self) -> None:
+        nn.init.normal_(self.character_embedding.weight, std=0.02)
+        # nn.init.normal_(self.pos_embedding.weight, std=0.02)
+
+    def _create_trg_mask(self, trg: Tensor) -> Tensor:
+        # Move this outside the transformer.
+        trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
+        trg_len = trg.shape[1]
+        trg_sub_mask = torch.tril(
+            torch.ones((trg_len, trg_len), device=trg.device)
+        ).bool()
+        trg_mask = trg_pad_mask & trg_sub_mask
+        return trg_mask
+
+    def encoder(self, src: Tensor) -> Tensor:
+        """Forward pass with the encoder of the transformer."""
+        return self.transformer.encoder(src)
+
+    def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
+        """Forward pass with the decoder of the transformer + classification head."""
+        return self.head(
+            self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+        )
+
+    def extract_image_features(self, src: Tensor) -> Tensor:
+        """Extracts image features with a backbone neural network.
+
+        It seem like the winning idea was to swap channels and width dimension and collapse
+        the height dimension. The transformer is learning like a baby with this implementation!!! :D
+        Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D
+
+        Args:
+            src (Tensor): Input tensor.
+
+        Returns:
+            Tensor: A input src to the transformer.
+
+        """
+        # If batch dimension is missing, it needs to be added.
+        if len(src.shape) < 4:
+            src = src[(None,) * (4 - len(src.shape))]
+
+        patches = rearrange(
+            src,
+            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
+            p1=self.patch_dim[0],
+            p2=self.patch_dim[1],
+        )
+
+        # From patches to encoded sequence.
+        x = self.patch_to_embedding(patches)
+        b, n, _ = x.shape
+        cls_tokens = repeat(self.cls_token, "() n d -> b n d", b=b)
+        x = torch.cat((cls_tokens, x), dim=1)
+        x += self.pos_embedding[:, : (n + 1)]
+        x = self.dropout(x)
+
+        return x
+
+    def target_embedding(self, trg: Tensor) -> Tuple[Tensor, Tensor]:
+        """Encodes target tensor with embedding and postion.
+
+        Args:
+            trg (Tensor): Target tensor.
+
+        Returns:
+            Tuple[Tensor, Tensor]: Encoded target tensor and target mask.
+
+        """
+        _, n = trg.shape
+        trg = self.character_embedding(trg.long())
+        trg += self.pos_embedding[:, :n]
+        return trg
+
+    def decode_image_features(self, h: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+        """Takes images features from the backbone and decodes them with the transformer."""
+        trg_mask = self._create_trg_mask(trg)
+        trg = self.target_embedding(trg)
+        out = self.transformer(h, trg, trg_mask=trg_mask)
+
+        logits = self.head(out)
+        return logits
+
+    def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+        """Forward pass with CNN transfomer."""
+        h = self.extract_image_features(x)
+        logits = self.decode_image_features(h, trg)
+        return logits
diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py
new file mode 100644
index 0000000..c673d96
--- /dev/null
+++ b/text_recognizer/networks/vq_transformer.py
@@ -0,0 +1,150 @@
+"""A VQ-Transformer for image to text recognition."""
+from typing import Dict, Optional, Tuple
+
+from einops import rearrange, repeat
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer import PositionalEncoding, Transformer
+from text_recognizer.networks.util import activation_function
+from text_recognizer.networks.util import configure_backbone
+from text_recognizer.networks.vqvae.encoder import _ResidualBlock
+
+
+class VQTransformer(nn.Module):
+    """VQ+Transfomer for image to character sequence prediction."""
+
+    def __init__(
+        self,
+        num_encoder_layers: int,
+        num_decoder_layers: int,
+        hidden_dim: int,
+        vocab_size: int,
+        num_heads: int,
+        adaptive_pool_dim: Tuple,
+        expansion_dim: int,
+        dropout_rate: float,
+        trg_pad_index: int,
+        max_len: int,
+        backbone: str,
+        backbone_args: Optional[Dict] = None,
+        activation: str = "gelu",
+    ) -> None:
+        super().__init__()
+
+        # Configure vector quantized backbone.
+        self.backbone = configure_backbone(backbone, backbone_args)
+        self.conv = nn.Sequential(
+            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2),
+            nn.ReLU(inplace=True),
+        )
+
+        # Configure embeddings for Transformer network.
+        self.trg_pad_index = trg_pad_index
+        self.vocab_size = vocab_size
+        self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim)
+        self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
+        self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+        nn.init.normal_(self.character_embedding.weight, std=0.02)
+
+        self.adaptive_pool = (
+            nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None
+        )
+
+        self.transformer = Transformer(
+            num_encoder_layers,
+            num_decoder_layers,
+            hidden_dim,
+            num_heads,
+            expansion_dim,
+            dropout_rate,
+            activation,
+        )
+
+        self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
+
+    def _create_trg_mask(self, trg: Tensor) -> Tensor:
+        # Move this outside the transformer.
+        trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
+        trg_len = trg.shape[1]
+        trg_sub_mask = torch.tril(
+            torch.ones((trg_len, trg_len), device=trg.device)
+        ).bool()
+        trg_mask = trg_pad_mask & trg_sub_mask
+        return trg_mask
+
+    def encoder(self, src: Tensor) -> Tensor:
+        """Forward pass with the encoder of the transformer."""
+        return self.transformer.encoder(src)
+
+    def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
+        """Forward pass with the decoder of the transformer + classification head."""
+        return self.head(
+            self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+        )
+
+    def extract_image_features(self, src: Tensor) -> Tuple[Tensor, Tensor]:
+        """Extracts image features with a backbone neural network.
+
+        It seem like the winning idea was to swap channels and width dimension and collapse
+        the height dimension. The transformer is learning like a baby with this implementation!!! :D
+        Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D
+
+        Args:
+            src (Tensor): Input tensor.
+
+        Returns:
+            Tensor: The input src to the transformer and the vq loss.
+
+        """
+        # If batch dimension is missing, it needs to be added.
+        if len(src.shape) < 4:
+            src = src[(None,) * (4 - len(src.shape))]
+        src, vq_loss = self.backbone.encode(src)
+        # src = self.backbone.decoder.res_block(src)
+        src = self.conv(src)
+
+        if self.adaptive_pool is not None:
+            src = rearrange(src, "b c h w -> b w c h")
+            src = self.adaptive_pool(src)
+            src = src.squeeze(3)
+        else:
+            src = rearrange(src, "b c h w -> b (w h) c")
+
+        b, t, _ = src.shape
+
+        src += self.src_position_embedding[:, :t]
+
+        return src, vq_loss
+
+    def target_embedding(self, trg: Tensor) -> Tensor:
+        """Encodes target tensor with embedding and postion.
+
+        Args:
+            trg (Tensor): Target tensor.
+
+        Returns:
+            Tensor: Encoded target tensor.
+
+        """
+        trg = self.character_embedding(trg.long())
+        trg = self.trg_position_encoding(trg)
+        return trg
+
+    def decode_image_features(
+        self, image_features: Tensor, trg: Optional[Tensor] = None
+    ) -> Tensor:
+        """Takes images features from the backbone and decodes them with the transformer."""
+        trg_mask = self._create_trg_mask(trg)
+        trg = self.target_embedding(trg)
+        out = self.transformer(image_features, trg, trg_mask=trg_mask)
+
+        logits = self.head(out)
+        return logits
+
+    def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+        """Forward pass with CNN transfomer."""
+        image_features, vq_loss = self.extract_image_features(x)
+        logits = self.decode_image_features(image_features, trg)
+        return logits, vq_loss
diff --git a/text_recognizer/networks/vqvae/__init__.py b/text_recognizer/networks/vqvae/__init__.py
new file mode 100644
index 0000000..763953c
--- /dev/null
+++ b/text_recognizer/networks/vqvae/__init__.py
@@ -0,0 +1,5 @@
+"""VQ-VAE module."""
+from .decoder import Decoder
+from .encoder import Encoder
+from .vector_quantizer import VectorQuantizer
+from .vqvae import VQVAE
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
new file mode 100644
index 0000000..8847aba
--- /dev/null
+++ b/text_recognizer/networks/vqvae/decoder.py
@@ -0,0 +1,133 @@
+"""CNN decoder for the VQ-VAE."""
+
+from typing import List, Optional, Tuple, Type
+
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import activation_function
+from text_recognizer.networks.vqvae.encoder import _ResidualBlock
+
+
+class Decoder(nn.Module):
+    """A CNN encoder network."""
+
+    def __init__(
+        self,
+        channels: List[int],
+        kernel_sizes: List[int],
+        strides: List[int],
+        num_residual_layers: int,
+        embedding_dim: int,
+        upsampling: Optional[List[List[int]]] = None,
+        activation: str = "leaky_relu",
+        dropout_rate: float = 0.0,
+    ) -> None:
+        super().__init__()
+
+        if dropout_rate:
+            if activation == "selu":
+                dropout = nn.AlphaDropout(p=dropout_rate)
+            else:
+                dropout = nn.Dropout(p=dropout_rate)
+        else:
+            dropout = None
+
+        self.upsampling = upsampling
+
+        self.res_block = nn.ModuleList([])
+        self.upsampling_block = nn.ModuleList([])
+
+        self.embedding_dim = embedding_dim
+        activation = activation_function(activation)
+
+        # Configure encoder.
+        self.decoder = self._build_decoder(
+            channels, kernel_sizes, strides, num_residual_layers, activation, dropout,
+        )
+
+    def _build_decompression_block(
+        self,
+        in_channels: int,
+        channels: int,
+        kernel_sizes: List[int],
+        strides: List[int],
+        activation: Type[nn.Module],
+        dropout: Optional[Type[nn.Module]],
+    ) -> nn.ModuleList:
+        modules = nn.ModuleList([])
+        configuration = zip(channels, kernel_sizes, strides)
+        for i, (out_channels, kernel_size, stride) in enumerate(configuration):
+            modules.append(
+                nn.Sequential(
+                    nn.ConvTranspose2d(
+                        in_channels,
+                        out_channels,
+                        kernel_size,
+                        stride=stride,
+                        padding=1,
+                    ),
+                    activation,
+                )
+            )
+
+            if i < len(self.upsampling):
+                modules.append(nn.Upsample(size=self.upsampling[i]),)
+
+            if dropout is not None:
+                modules.append(dropout)
+
+            in_channels = out_channels
+
+        modules.extend(
+            nn.Sequential(
+                nn.ConvTranspose2d(
+                    in_channels, 1, kernel_size=kernel_size, stride=stride, padding=1
+                ),
+                nn.Tanh(),
+            )
+        )
+
+        return modules
+
+    def _build_decoder(
+        self,
+        channels: int,
+        kernel_sizes: List[int],
+        strides: List[int],
+        num_residual_layers: int,
+        activation: Type[nn.Module],
+        dropout: Optional[Type[nn.Module]],
+    ) -> nn.Sequential:
+
+        self.res_block.append(
+            nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,)
+        )
+
+        # Bottleneck module.
+        self.res_block.extend(
+            nn.ModuleList(
+                [
+                    _ResidualBlock(channels[0], channels[0], dropout)
+                    for i in range(num_residual_layers)
+                ]
+            )
+        )
+
+        # Decompression module
+        self.upsampling_block.extend(
+            self._build_decompression_block(
+                channels[0], channels[1:], kernel_sizes, strides, activation, dropout
+            )
+        )
+
+        self.res_block = nn.Sequential(*self.res_block)
+        self.upsampling_block = nn.Sequential(*self.upsampling_block)
+
+        return nn.Sequential(self.res_block, self.upsampling_block)
+
+    def forward(self, z_q: Tensor) -> Tensor:
+        """Reconstruct input from given codes."""
+        x_reconstruction = self.decoder(z_q)
+        return x_reconstruction
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
new file mode 100644
index 0000000..d3adac5
--- /dev/null
+++ b/text_recognizer/networks/vqvae/encoder.py
@@ -0,0 +1,147 @@
+"""CNN encoder for the VQ-VAE."""
+from typing import List, Optional, Tuple, Type
+
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import activation_function
+from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer
+
+
+class _ResidualBlock(nn.Module):
+    def __init__(
+        self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]],
+    ) -> None:
+        super().__init__()
+        self.block = [
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False),
+        ]
+
+        if dropout is not None:
+            self.block.append(dropout)
+
+        self.block = nn.Sequential(*self.block)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Apply the residual forward pass."""
+        return x + self.block(x)
+
+
+class Encoder(nn.Module):
+    """A CNN encoder network."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        channels: List[int],
+        kernel_sizes: List[int],
+        strides: List[int],
+        num_residual_layers: int,
+        embedding_dim: int,
+        num_embeddings: int,
+        beta: float = 0.25,
+        activation: str = "leaky_relu",
+        dropout_rate: float = 0.0,
+    ) -> None:
+        super().__init__()
+
+        if dropout_rate:
+            if activation == "selu":
+                dropout = nn.AlphaDropout(p=dropout_rate)
+            else:
+                dropout = nn.Dropout(p=dropout_rate)
+        else:
+            dropout = None
+
+        self.embedding_dim = embedding_dim
+        self.num_embeddings = num_embeddings
+        self.beta = beta
+        activation = activation_function(activation)
+
+        # Configure encoder.
+        self.encoder = self._build_encoder(
+            in_channels,
+            channels,
+            kernel_sizes,
+            strides,
+            num_residual_layers,
+            activation,
+            dropout,
+        )
+
+        # Configure Vector Quantizer.
+        self.vector_quantizer = VectorQuantizer(
+            self.num_embeddings, self.embedding_dim, self.beta
+        )
+
+    def _build_compression_block(
+        self,
+        in_channels: int,
+        channels: int,
+        kernel_sizes: List[int],
+        strides: List[int],
+        activation: Type[nn.Module],
+        dropout: Optional[Type[nn.Module]],
+    ) -> nn.ModuleList:
+        modules = nn.ModuleList([])
+        configuration = zip(channels, kernel_sizes, strides)
+        for out_channels, kernel_size, stride in configuration:
+            modules.append(
+                nn.Sequential(
+                    nn.Conv2d(
+                        in_channels, out_channels, kernel_size, stride=stride, padding=1
+                    ),
+                    activation,
+                )
+            )
+
+            if dropout is not None:
+                modules.append(dropout)
+
+            in_channels = out_channels
+
+        return modules
+
+    def _build_encoder(
+        self,
+        in_channels: int,
+        channels: int,
+        kernel_sizes: List[int],
+        strides: List[int],
+        num_residual_layers: int,
+        activation: Type[nn.Module],
+        dropout: Optional[Type[nn.Module]],
+    ) -> nn.Sequential:
+        encoder = nn.ModuleList([])
+
+        # compression module
+        encoder.extend(
+            self._build_compression_block(
+                in_channels, channels, kernel_sizes, strides, activation, dropout
+            )
+        )
+
+        # Bottleneck module.
+        encoder.extend(
+            nn.ModuleList(
+                [
+                    _ResidualBlock(channels[-1], channels[-1], dropout)
+                    for i in range(num_residual_layers)
+                ]
+            )
+        )
+
+        encoder.append(
+            nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,)
+        )
+
+        return nn.Sequential(*encoder)
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+        """Encodes input into a discrete representation."""
+        z_e = self.encoder(x)
+        z_q, vq_loss = self.vector_quantizer(z_e)
+        return z_q, vq_loss
diff --git a/text_recognizer/networks/vqvae/vector_quantizer.py b/text_recognizer/networks/vqvae/vector_quantizer.py
new file mode 100644
index 0000000..f92c7ee
--- /dev/null
+++ b/text_recognizer/networks/vqvae/vector_quantizer.py
@@ -0,0 +1,119 @@
+"""Implementation of a Vector Quantized Variational AutoEncoder.
+
+Reference:
+https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py
+
+"""
+
+from einops import rearrange
+import torch
+from torch import nn
+from torch import Tensor
+from torch.nn import functional as F
+
+
+class VectorQuantizer(nn.Module):
+    """The codebook that contains quantized vectors."""
+
+    def __init__(
+        self, num_embeddings: int, embedding_dim: int, beta: float = 0.25
+    ) -> None:
+        super().__init__()
+        self.K = num_embeddings
+        self.D = embedding_dim
+        self.beta = beta
+
+        self.embedding = nn.Embedding(self.K, self.D)
+
+        # Initialize the codebook.
+        nn.init.uniform_(self.embedding.weight, -1 / self.K, 1 / self.K)
+
+    def discretization_bottleneck(self, latent: Tensor) -> Tensor:
+        """Computes the code nearest to the latent representation.
+
+        First we compute the posterior categorical distribution, and then map
+        the latent representation to the nearest element of the embedding.
+
+        Args:
+            latent (Tensor): The latent representation.
+
+        Shape:
+            - latent :math:`(B x H x W, D)`
+
+        Returns:
+            Tensor: The quantized embedding vector.
+
+        """
+        # Store latent shape.
+        b, h, w, d = latent.shape
+
+        # Flatten the latent representation to 2D.
+        latent = rearrange(latent, "b h w d -> (b h w) d")
+
+        # Compute the L2 distance between the latents and the embeddings.
+        l2_distance = (
+            torch.sum(latent ** 2, dim=1, keepdim=True)
+            + torch.sum(self.embedding.weight ** 2, dim=1)
+            - 2 * latent @ self.embedding.weight.t()
+        )  # [BHW x K]
+
+        # Find the embedding k nearest to each latent.
+        encoding_indices = torch.argmin(l2_distance, dim=1).unsqueeze(1)  # [BHW, 1]
+
+        # Convert to one-hot encodings, aka discrete bottleneck.
+        one_hot_encoding = torch.zeros(
+            encoding_indices.shape[0], self.K, device=latent.device
+        )
+        one_hot_encoding.scatter_(1, encoding_indices, 1)  # [BHW x K]
+
+        # Embedding quantization.
+        quantized_latent = one_hot_encoding @ self.embedding.weight  # [BHW, D]
+        quantized_latent = rearrange(
+            quantized_latent, "(b h w) d -> b h w d", b=b, h=h, w=w
+        )
+
+        return quantized_latent
+
+    def vq_loss(self, latent: Tensor, quantized_latent: Tensor) -> Tensor:
+        """Vector Quantization loss.
+
+        The vector quantization algorithm allows us to create a codebook. The VQ
+        algorithm works by moving the embedding vectors towards the encoder outputs.
+
+        The embedding loss moves the embedding vector towards the encoder outputs. The
+        .detach() works as the stop gradient (sg) described in the paper.
+
+        Because the volume of the embedding space is dimensionless, it can arbitarily
+        grow if the embeddings are not trained as fast as the encoder parameters. To
+        mitigate this, a commitment loss is added in the second term which makes sure
+        that the encoder commits to an embedding and that its output does not grow.
+
+        Args:
+            latent (Tensor): The encoder output.
+            quantized_latent (Tensor): The quantized latent.
+
+        Returns:
+            Tensor: The combinded VQ loss.
+
+        """
+        embedding_loss = F.mse_loss(quantized_latent, latent.detach())
+        commitment_loss = F.mse_loss(quantized_latent.detach(), latent)
+        return embedding_loss + self.beta * commitment_loss
+
+    def forward(self, latent: Tensor) -> Tensor:
+        """Forward pass that returns the quantized vector and the vq loss."""
+        # Rearrange latent representation s.t. the hidden dim is at the end.
+        latent = rearrange(latent, "b d h w -> b h w d")
+
+        # Maps latent to the nearest code in the codebook.
+        quantized_latent = self.discretization_bottleneck(latent)
+
+        loss = self.vq_loss(latent, quantized_latent)
+
+        # Add residue to the quantized latent.
+        quantized_latent = latent + (quantized_latent - latent).detach()
+
+        # Rearrange the quantized shape back to the original shape.
+        quantized_latent = rearrange(quantized_latent, "b h w d -> b d h w")
+
+        return quantized_latent, loss
diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py
new file mode 100644
index 0000000..50448b4
--- /dev/null
+++ b/text_recognizer/networks/vqvae/vqvae.py
@@ -0,0 +1,74 @@
+"""The VQ-VAE."""
+
+from typing import List, Optional, Tuple, Type
+
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.vqvae import Decoder, Encoder
+
+
+class VQVAE(nn.Module):
+    """Vector Quantized Variational AutoEncoder."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        channels: List[int],
+        kernel_sizes: List[int],
+        strides: List[int],
+        num_residual_layers: int,
+        embedding_dim: int,
+        num_embeddings: int,
+        upsampling: Optional[List[List[int]]] = None,
+        beta: float = 0.25,
+        activation: str = "leaky_relu",
+        dropout_rate: float = 0.0,
+    ) -> None:
+        super().__init__()
+
+        # configure encoder.
+        self.encoder = Encoder(
+            in_channels,
+            channels,
+            kernel_sizes,
+            strides,
+            num_residual_layers,
+            embedding_dim,
+            num_embeddings,
+            beta,
+            activation,
+            dropout_rate,
+        )
+
+        # Configure decoder.
+        channels.reverse()
+        kernel_sizes.reverse()
+        strides.reverse()
+        self.decoder = Decoder(
+            channels,
+            kernel_sizes,
+            strides,
+            num_residual_layers,
+            embedding_dim,
+            upsampling,
+            activation,
+            dropout_rate,
+        )
+
+    def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+        """Encodes input to a latent code."""
+        return self.encoder(x)
+
+    def decode(self, z_q: Tensor) -> Tensor:
+        """Reconstructs input from latent codes."""
+        return self.decoder(z_q)
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+        """Compresses and decompresses input."""
+        if len(x.shape) < 4:
+            x = x[(None,) * (4 - len(x.shape))]
+        z_q, vq_loss = self.encode(x)
+        x_reconstruction = self.decode(z_q)
+        return x_reconstruction, vq_loss
diff --git a/text_recognizer/networks/wide_resnet.py b/text_recognizer/networks/wide_resnet.py
new file mode 100644
index 0000000..b767778
--- /dev/null
+++ b/text_recognizer/networks/wide_resnet.py
@@ -0,0 +1,221 @@
+"""Wide Residual CNN."""
+from functools import partial
+from typing import Callable, Dict, List, Optional, Type, Union
+
+from einops.layers.torch import Reduce
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import activation_function
+
+
+def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
+    """Helper function for a 3x3 2d convolution."""
+    return nn.Conv2d(
+        in_channels=in_planes,
+        out_channels=out_planes,
+        kernel_size=3,
+        stride=stride,
+        padding=1,
+        bias=False,
+    )
+
+
+def conv_init(module: Type[nn.Module]) -> None:
+    """Initializes the weights for convolution and batchnorms."""
+    classname = module.__class__.__name__
+    if classname.find("Conv") != -1:
+        nn.init.xavier_uniform_(module.weight, gain=np.sqrt(2))
+        nn.init.constant_(module.bias, 0)
+    elif classname.find("BatchNorm") != -1:
+        nn.init.constant_(module.weight, 1)
+        nn.init.constant_(module.bias, 0)
+
+
+class WideBlock(nn.Module):
+    """Block used in WideResNet."""
+
+    def __init__(
+        self,
+        in_planes: int,
+        out_planes: int,
+        dropout_rate: float,
+        stride: int = 1,
+        activation: str = "relu",
+    ) -> None:
+        super().__init__()
+        self.in_planes = in_planes
+        self.out_planes = out_planes
+        self.dropout_rate = dropout_rate
+        self.stride = stride
+        self.activation = activation_function(activation)
+
+        # Build blocks.
+        self.blocks = nn.Sequential(
+            nn.BatchNorm2d(self.in_planes),
+            self.activation,
+            conv3x3(in_planes=self.in_planes, out_planes=self.out_planes),
+            nn.Dropout(p=self.dropout_rate),
+            nn.BatchNorm2d(self.out_planes),
+            self.activation,
+            conv3x3(
+                in_planes=self.out_planes,
+                out_planes=self.out_planes,
+                stride=self.stride,
+            ),
+        )
+
+        self.shortcut = (
+            nn.Sequential(
+                nn.Conv2d(
+                    in_channels=self.in_planes,
+                    out_channels=self.out_planes,
+                    kernel_size=1,
+                    stride=self.stride,
+                    bias=False,
+                ),
+            )
+            if self._apply_shortcut
+            else None
+        )
+
+    @property
+    def _apply_shortcut(self) -> bool:
+        """If shortcut should be applied or not."""
+        return self.stride != 1 or self.in_planes != self.out_planes
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward pass."""
+        residual = x
+        if self._apply_shortcut:
+            residual = self.shortcut(x)
+        x = self.blocks(x)
+        x += residual
+        return x
+
+
+class WideResidualNetwork(nn.Module):
+    """WideResNet for character predictions.
+
+    Can be used for classification or encoding of images to a latent vector.
+
+    """
+
+    def __init__(
+        self,
+        in_channels: int = 1,
+        in_planes: int = 16,
+        num_classes: int = 80,
+        depth: int = 16,
+        width_factor: int = 10,
+        dropout_rate: float = 0.0,
+        num_layers: int = 3,
+        block: Type[nn.Module] = WideBlock,
+        num_stages: Optional[List[int]] = None,
+        activation: str = "relu",
+        use_decoder: bool = True,
+    ) -> None:
+        """The initialization of the WideResNet.
+
+        Args:
+            in_channels (int): Number of input channels. Defaults to 1.
+            in_planes (int): Number of channels to use in the first output kernel. Defaults to 16.
+            num_classes (int): Number of classes. Defaults to 80.
+            depth (int): Set the number of blocks to use. Defaults to 16.
+            width_factor (int): Factor for scaling the number of channels in the network. Defaults to 10.
+            dropout_rate (float): The dropout rate. Defaults to 0.0.
+            num_layers (int): Number of layers of blocks. Defaults to 3.
+            block (Type[nn.Module]): The default block is WideBlock. Defaults to WideBlock.
+            num_stages (List[int]): If given, will use these channel values. Defaults to None.
+            activation (str): Name of the activation to use. Defaults to "relu".
+            use_decoder (bool): If True, the network output character predictions, if False, the network outputs a
+                latent vector. Defaults to True.
+
+        Raises:
+            RuntimeError: If the depth is not of the size `6n+4`.
+
+        """
+
+        super().__init__()
+        if (depth - 4) % 6 != 0:
+            raise RuntimeError("Wide-resnet depth should be 6n+4")
+        self.in_channels = in_channels
+        self.in_planes = in_planes
+        self.num_classes = num_classes
+        self.num_blocks = (depth - 4) // 6
+        self.width_factor = width_factor
+        self.num_layers = num_layers
+        self.block = block
+        self.dropout_rate = dropout_rate
+        self.activation = activation_function(activation)
+
+        if num_stages is None:
+            self.num_stages = [self.in_planes] + [
+                self.in_planes * 2 ** n * self.width_factor
+                for n in range(self.num_layers)
+            ]
+        else:
+            self.num_stages = [self.in_planes] + num_stages
+
+        self.num_stages = list(zip(self.num_stages, self.num_stages[1:]))
+        self.strides = [1] + [2] * (self.num_layers - 1)
+
+        self.encoder = nn.Sequential(
+            conv3x3(in_planes=self.in_channels, out_planes=self.in_planes),
+            *[
+                self._configure_wide_layer(
+                    in_planes=in_planes,
+                    out_planes=out_planes,
+                    stride=stride,
+                    activation=activation,
+                )
+                for (in_planes, out_planes), stride in zip(
+                    self.num_stages, self.strides
+                )
+            ],
+        )
+
+        self.decoder = (
+            nn.Sequential(
+                nn.BatchNorm2d(self.num_stages[-1][-1], momentum=0.8),
+                self.activation,
+                Reduce("b c h w -> b c", "mean"),
+                nn.Linear(
+                    in_features=self.num_stages[-1][-1], out_features=self.num_classes
+                ),
+            )
+            if use_decoder
+            else None
+        )
+
+        # self.apply(conv_init)
+
+    def _configure_wide_layer(
+        self, in_planes: int, out_planes: int, stride: int, activation: str
+    ) -> List:
+        strides = [stride] + [1] * (self.num_blocks - 1)
+        planes = [out_planes] * len(strides)
+        planes = [(in_planes, out_planes)] + list(zip(planes, planes[1:]))
+        return nn.Sequential(
+            *[
+                self.block(
+                    in_planes=in_planes,
+                    out_planes=out_planes,
+                    dropout_rate=self.dropout_rate,
+                    stride=stride,
+                    activation=activation,
+                )
+                for (in_planes, out_planes), stride in zip(planes, strides)
+            ]
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Feedforward pass."""
+        if len(x.shape) < 4:
+            x = x[(None,) * int(4 - len(x.shape))]
+        x = self.encoder(x)
+        if self.decoder is not None:
+            x = self.decoder(x)
+        return x
-- 
cgit v1.2.3-70-g09d2