From dc28cbe2b4ed77be92ee8b2b69a20689c3bf02a4 Mon Sep 17 00:00:00 2001
From: aktersnurra <gustaf.rydholm@gmail.com>
Date: Sun, 8 Nov 2020 14:54:44 +0100
Subject: Add transformer, CRNN, DenseNet, and sparse MLP networks

---
 src/text_recognizer/networks/__init__.py           |  20 +-
 src/text_recognizer/networks/cnn_transformer.py    | 135 ++++++++++++
 .../networks/cnn_transformer_encoder.py            |  73 +++++++
 src/text_recognizer/networks/crnn.py               | 108 +++++++++
 src/text_recognizer/networks/ctc.py                |   2 +-
 src/text_recognizer/networks/densenet.py           | 225 +++++++++++++++++++
 src/text_recognizer/networks/lenet.py              |   6 +-
 src/text_recognizer/networks/line_lstm_ctc.py      | 120 ----------
 src/text_recognizer/networks/loss.py               |  69 ++++++
 src/text_recognizer/networks/losses.py             |  31 ---
 src/text_recognizer/networks/misc.py               |  45 ----
 src/text_recognizer/networks/mlp.py                |   6 +-
 src/text_recognizer/networks/residual_network.py   |   6 +-
 src/text_recognizer/networks/sparse_mlp.py         |  78 +++++++
 src/text_recognizer/networks/transformer.py        |   5 -
 .../networks/transformer/__init__.py               |   3 +
 .../networks/transformer/attention.py              |  93 ++++++++
 .../networks/transformer/positional_encoding.py    |  32 +++
 .../networks/transformer/transformer.py            | 242 +++++++++++++++++++++
 src/text_recognizer/networks/util.py               |  83 +++++++
 src/text_recognizer/networks/vision_transformer.py | 159 ++++++++++++++
 src/text_recognizer/networks/wide_resnet.py        |   6 +-
 22 files changed, 1329 insertions(+), 218 deletions(-)
 create mode 100644 src/text_recognizer/networks/cnn_transformer.py
 create mode 100644 src/text_recognizer/networks/cnn_transformer_encoder.py
 create mode 100644 src/text_recognizer/networks/crnn.py
 create mode 100644 src/text_recognizer/networks/densenet.py
 delete mode 100644 src/text_recognizer/networks/line_lstm_ctc.py
 create mode 100644 src/text_recognizer/networks/loss.py
 delete mode 100644 src/text_recognizer/networks/losses.py
 delete mode 100644 src/text_recognizer/networks/misc.py
 create mode 100644 src/text_recognizer/networks/sparse_mlp.py
 delete mode 100644 src/text_recognizer/networks/transformer.py
 create mode 100644 src/text_recognizer/networks/transformer/__init__.py
 create mode 100644 src/text_recognizer/networks/transformer/attention.py
 create mode 100644 src/text_recognizer/networks/transformer/positional_encoding.py
 create mode 100644 src/text_recognizer/networks/transformer/transformer.py
 create mode 100644 src/text_recognizer/networks/util.py
 create mode 100644 src/text_recognizer/networks/vision_transformer.py

(limited to 'src/text_recognizer/networks')

diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index a39975f..6d88768 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -1,21 +1,33 @@
 """Network modules."""
+from .cnn_transformer import CNNTransformer
+from .cnn_transformer_encoder import CNNTransformerEncoder
+from .crnn import ConvolutionalRecurrentNetwork
 from .ctc import greedy_decoder
+from .densenet import DenseNet
 from .lenet import LeNet
-from .line_lstm_ctc import LineRecurrentNetwork
-from .losses import EmbeddingLoss
-from .misc import sliding_window
+from .loss import EmbeddingLoss
 from .mlp import MLP
 from .residual_network import ResidualNetwork, ResidualNetworkEncoder
+from .sparse_mlp import SparseMLP
+from .transformer import Transformer
+from .util import sliding_window
+from .vision_transformer import VisionTransformer
 from .wide_resnet import WideResidualNetwork
 
 __all__ = [
+    "CNNTransformer",
+    "CNNTransformerEncoder",
+    "ConvolutionalRecurrentNetwork",
+    "DenseNet",
     "EmbeddingLoss",
     "greedy_decoder",
     "MLP",
     "LeNet",
-    "LineRecurrentNetwork",
     "ResidualNetwork",
     "ResidualNetworkEncoder",
     "sliding_window",
+    "Transformer",
+    "SparseMLP",
+    "VisionTransformer",
     "WideResidualNetwork",
 ]
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
new file mode 100644
index 0000000..3da2c9f
--- /dev/null
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -0,0 +1,135 @@
+"""A DETR style transfomers but for text recognition."""
+from typing import Dict, Optional, Tuple
+
+from einops import rearrange
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer import PositionalEncoding, Transformer
+from text_recognizer.networks.util import configure_backbone
+
+
+class CNNTransformer(nn.Module):
+    """CNN+Transfomer for image to sequence prediction, sort of based on the ideas from DETR."""
+
+    def __init__(
+        self,
+        num_encoder_layers: int,
+        num_decoder_layers: int,
+        hidden_dim: int,
+        vocab_size: int,
+        num_heads: int,
+        adaptive_pool_dim: Tuple,
+        expansion_dim: int,
+        dropout_rate: float,
+        trg_pad_index: int,
+        backbone: str,
+        out_channels: int,
+        max_len: int,
+        backbone_args: Optional[Dict] = None,
+        activation: str = "gelu",
+    ) -> None:
+        super().__init__()
+        self.trg_pad_index = trg_pad_index
+
+        self.backbone = configure_backbone(backbone, backbone_args)
+        self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
+
+        # self.conv = nn.Conv2d(out_channels, max_len, kernel_size=1)
+
+        self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+        self.row_embed = nn.Parameter(torch.rand(max_len, max_len // 2))
+        self.col_embed = nn.Parameter(torch.rand(max_len, max_len // 2))
+
+        self.adaptive_pool = (
+            nn.AdaptiveAvgPool2d(adaptive_pool_dim) if adaptive_pool_dim else None
+        )
+
+        self.transformer = Transformer(
+            num_encoder_layers,
+            num_decoder_layers,
+            hidden_dim,
+            num_heads,
+            expansion_dim,
+            dropout_rate,
+            activation,
+        )
+
+        self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
+
+    def _create_trg_mask(self, trg: Tensor) -> Tensor:
+        # Move this outside the transformer.
+        trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
+        trg_len = trg.shape[1]
+        trg_sub_mask = torch.tril(
+            torch.ones((trg_len, trg_len), device=trg.device)
+        ).bool()
+        trg_mask = trg_pad_mask & trg_sub_mask
+        return trg_mask
+
+    def encoder(self, src: Tensor) -> Tensor:
+        """Forward pass with the encoder of the transformer."""
+        return self.transformer.encoder(src)
+
+    def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
+        """Forward pass with the decoder of the transformer + classification head."""
+        return self.head(
+            self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+        )
+
+    def preprocess_input(self, src: Tensor) -> Tensor:
+        """Encodes src with a backbone network and a positional encoding.
+
+        Args:
+            src (Tensor): Input tensor.
+
+        Returns:
+            Tensor: An input src for the transformer.
+
+        """
+        # If the batch dimension is missing, it needs to be added.
+        if len(src.shape) < 4:
+            src = src[(None,) * (4 - len(src.shape))]
+        src = self.backbone(src)
+        # src = self.conv(src)
+        if self.adaptive_pool is not None:
+            src = self.adaptive_pool(src)
+        H, W = src.shape[-2:]
+        src = rearrange(src, "b t h w -> b t (h w)")
+
+        # construct positional encodings
+        pos = torch.cat(
+            [
+                self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
+                self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
+            ],
+            dim=-1,
+        ).unsqueeze(0)
+        pos = rearrange(pos, "b h w l -> b l (h w)")
+        src = pos + 0.1 * src
+        return src
+
+    def preprocess_target(self, trg: Tensor) -> Tensor:
+        """Encodes the target tensor with embedding and position.
+
+        Args:
+            trg (Tensor): Target tensor.
+
+        Returns:
+            Tensor: Encoded target tensor.
+
+        """
+        trg = self.character_embedding(trg.long())
+        trg = self.position_encoding(trg)
+        return trg
+
+    def forward(self, x: Tensor, trg: Tensor) -> Tensor:
+        """Forward pass with the CNN transformer."""
+        h = self.preprocess_input(x)
+        trg_mask = self._create_trg_mask(trg)
+        trg = self.preprocess_target(trg)
+        out = self.transformer(h, trg, trg_mask=trg_mask)
+
+        logits = self.head(out)
+        return logits
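
The mask built by _create_trg_mask combines a padding mask with a
lower-triangular (subsequent) mask, so the decoder cannot attend to pad
positions or to future characters. A standalone sketch with illustrative
token ids:

    import torch

    trg_pad_index = 79  # illustrative pad index
    trg = torch.tensor([[5, 12, 3, trg_pad_index, trg_pad_index]])  # (batch, length)

    trg_pad_mask = (trg != trg_pad_index)[:, None, None]  # (1, 1, 1, 5)
    trg_len = trg.shape[1]
    trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len))).bool()
    trg_mask = trg_pad_mask & trg_sub_mask  # broadcasts to (1, 1, 5, 5)
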
diff --git a/src/text_recognizer/networks/cnn_transformer_encoder.py b/src/text_recognizer/networks/cnn_transformer_encoder.py
new file mode 100644
index 0000000..93626bf
--- /dev/null
+++ b/src/text_recognizer/networks/cnn_transformer_encoder.py
@@ -0,0 +1,73 @@
+"""Network with a CNN backend and a transformer encoder head."""
+from typing import Dict, Optional
+
+from einops import rearrange
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer import PositionalEncoding
+from text_recognizer.networks.util import configure_backbone
+
+
+class CNNTransformerEncoder(nn.Module):
+    """A CNN backbone with Transformer Encoder frontend for sequence prediction."""
+
+    def __init__(
+        self,
+        backbone: str,
+        backbone_args: Dict,
+        mlp_dim: int,
+        d_model: int,
+        nhead: int = 8,
+        dropout_rate: float = 0.1,
+        activation: str = "relu",
+        num_layers: int = 6,
+        num_classes: int = 80,
+        num_channels: int = 256,
+        max_len: int = 97,
+    ) -> None:
+        super().__init__()
+        self.d_model = d_model
+        self.nhead = nhead
+        self.dropout_rate = dropout_rate
+        self.activation = activation
+        self.num_layers = num_layers
+
+        self.backbone = configure_backbone(backbone, backbone_args)
+        self.position_encoding = PositionalEncoding(d_model, dropout_rate)
+        self.encoder = self._configure_encoder()
+
+        self.conv = nn.Conv2d(num_channels, max_len, kernel_size=1)
+
+        self.mlp = nn.Linear(mlp_dim, d_model)
+
+        self.head = nn.Linear(d_model, num_classes)
+
+    def _configure_encoder(self) -> nn.TransformerEncoder:
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=self.d_model,
+            nhead=self.nhead,
+            dropout=self.dropout_rate,
+            activation=self.activation,
+        )
+        norm = nn.LayerNorm(self.d_model)
+        return nn.TransformerEncoder(
+            encoder_layer=encoder_layer, num_layers=self.num_layers, norm=norm
+        )
+
+    def forward(self, x: Tensor, targets: Optional[Tensor] = None) -> Tensor:
+        """Forward pass through the network."""
+        if len(x.shape) < 4:
+            x = x[(None,) * (4 - len(x.shape))]
+
+        x = self.conv(self.backbone(x))
+        x = rearrange(x, "b c h w -> b c (h w)")
+        x = self.mlp(x)
+        x = self.position_encoding(x)
+        x = rearrange(x, "b c h-> c b h")
+        x = self.encoder(x)
+        x = rearrange(x, "c b h-> b c h")
+        logits = self.head(x)
+
+        return logits
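
The rearranges around self.encoder are there because torch.nn.TransformerEncoder
expects input of shape (sequence, batch, features). A minimal sketch with
illustrative sizes:

    import torch
    from torch import nn

    encoder_layer = nn.TransformerEncoderLayer(d_model=64, nhead=8)
    encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

    x = torch.rand(16, 97, 64)         # (batch, sequence, features)
    out = encoder(x.permute(1, 0, 2))  # (sequence, batch, features)
    out = out.permute(1, 0, 2)         # back to (batch, sequence, features)
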
diff --git a/src/text_recognizer/networks/crnn.py b/src/text_recognizer/networks/crnn.py
new file mode 100644
index 0000000..9747429
--- /dev/null
+++ b/src/text_recognizer/networks/crnn.py
@@ -0,0 +1,108 @@
+"""LSTM with CTC for handwritten text recognition within a line."""
+from typing import Dict, Optional, Tuple
+
+from einops import rearrange, reduce
+from einops.layers.torch import Rearrange
+from loguru import logger
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import configure_backbone
+
+
+class ConvolutionalRecurrentNetwork(nn.Module):
+    """Network that takes a image of a text line and predicts tokens that are in the image."""
+
+    def __init__(
+        self,
+        backbone: str,
+        backbone_args: Optional[Dict] = None,
+        input_size: int = 128,
+        hidden_size: int = 128,
+        bidirectional: bool = False,
+        num_layers: int = 1,
+        num_classes: int = 80,
+        patch_size: Tuple[int, int] = (28, 28),
+        stride: Tuple[int, int] = (1, 14),
+        recurrent_cell: str = "lstm",
+        avg_pool: bool = False,
+        use_sliding_window: bool = True,
+    ) -> None:
+        super().__init__()
+        self.backbone_args = backbone_args or {}
+        self.patch_size = patch_size
+        self.stride = stride
+        self.sliding_window = (
+            self._configure_sliding_window() if use_sliding_window else None
+        )
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.backbone = configure_backbone(backbone, backbone_args)
+        self.bidirectional = bidirectional
+        self.avg_pool = avg_pool
+
+        if recurrent_cell.upper() in ["LSTM", "GRU"]:
+            recurrent_cell = getattr(nn, recurrent_cell)
+        else:
+            logger.warning(
+                f"Option {recurrent_cell} not valid, defaulting to LSTM cell."
+            )
+            recurrent_cell = nn.LSTM
+
+        self.rnn = recurrent_cell(
+            input_size=self.input_size,
+            hidden_size=self.hidden_size,
+            bidirectional=bidirectional,
+            num_layers=num_layers,
+        )
+
+        decoder_size = self.hidden_size * 2 if self.bidirectional else self.hidden_size
+
+        self.decoder = nn.Sequential(
+            nn.Linear(in_features=decoder_size, out_features=num_classes),
+            nn.LogSoftmax(dim=2),
+        )
+
+    def _configure_sliding_window(self) -> nn.Sequential:
+        return nn.Sequential(
+            nn.Unfold(kernel_size=self.patch_size, stride=self.stride),
+            Rearrange(
+                "b (c h w) t -> b t c h w",
+                h=self.patch_size[0],
+                w=self.patch_size[1],
+                c=1,
+            ),
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM."""
+        if len(x.shape) < 4:
+            x = x[(None,) * (4 - len(x.shape))]
+
+        if self.sliding_window is not None:
+            # Create image patches with a sliding window kernel.
+            x = self.sliding_window(x)
+
+            # Flatten the batch and patch dimensions so the CNN processes one patch at a time.
+            b, t = x.shape[:2]
+            x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
+
+            x = self.backbone(x)
+
+            # Average pooling.
+            if self.avg_pool:
+                x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)
+            else:
+                x = rearrange(x, "(b t) h -> t b h", b=b, t=t)
+        else:
+            # Encode the entire image with a CNN, and use the channels as temporal dimension.
+            b = x.shape[0]
+            x = self.backbone(x)
+            x = rearrange(x, "b c h w -> c b (h w)", b=b)
+
+        # Sequence predictions.
+        x, _ = self.rnn(x)
+
+        # Sequence-to-classification layer.
+        x = self.decoder(x)
+        return x
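
The tensors fed to self.rnn are time-major, matching torch.nn.LSTM's
default (sequence, batch, features) layout. A sketch of the recurrent head
with illustrative sizes:

    import torch
    from torch import nn

    rnn = nn.LSTM(input_size=128, hidden_size=128, num_layers=1)
    decoder = nn.Sequential(nn.Linear(128, 80), nn.LogSoftmax(dim=2))

    x = torch.rand(67, 4, 128)  # (time steps, batch, features)
    out, _ = rnn(x)             # (67, 4, 128)
    log_probs = decoder(out)    # (67, 4, 80), ready for a CTC loss
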
diff --git a/src/text_recognizer/networks/ctc.py b/src/text_recognizer/networks/ctc.py
index 2493d5c..af9b700 100644
--- a/src/text_recognizer/networks/ctc.py
+++ b/src/text_recognizer/networks/ctc.py
@@ -33,7 +33,7 @@ def greedy_decoder(
     """
 
     if character_mapper is None:
-        character_mapper = EmnistMapper()
+        character_mapper = EmnistMapper(pad_token="_")  # noqa: S106
 
     predictions = rearrange(torch.argmax(predictions, dim=2), "t b -> b t")
     decoded_predictions = []
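
For reference, the standard best-path collapse behind greedy CTC decoding
(greedy_decoder additionally maps ids back to characters with EmnistMapper;
the blank id and sizes here are illustrative):

    import torch

    blank = 79
    log_probs = torch.randn(20, 1, 80)           # (time, batch, classes)
    best = torch.argmax(log_probs, dim=2)[:, 0]  # best class per time step

    decoded, previous = [], blank
    for token in best.tolist():
        if token != blank and token != previous:
            decoded.append(token)
        previous = token
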
diff --git a/src/text_recognizer/networks/densenet.py b/src/text_recognizer/networks/densenet.py
new file mode 100644
index 0000000..7dc58d9
--- /dev/null
+++ b/src/text_recognizer/networks/densenet.py
@@ -0,0 +1,225 @@
+"""Defines a Densely Connected Convolutional Networks in PyTorch.
+
+Sources:
+https://arxiv.org/abs/1608.06993
+https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py
+
+"""
+from typing import List, Tuple, Union
+
+from einops.layers.torch import Rearrange
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import activation_function
+
+
+class _DenseLayer(nn.Module):
+    """A dense layer with pre-batch norm -> activation function -> Conv-layer x 2."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        growth_rate: int,
+        bn_size: int,
+        dropout_rate: float,
+        activation: str = "relu",
+    ) -> None:
+        super().__init__()
+        activation_fn = activation_function(activation)
+        self.dense_layer = [
+            nn.BatchNorm2d(in_channels),
+            activation_fn,
+            nn.Conv2d(
+                in_channels=in_channels,
+                out_channels=bn_size * growth_rate,
+                kernel_size=1,
+                stride=1,
+                bias=False,
+            ),
+            nn.BatchNorm2d(bn_size * growth_rate),
+            activation_fn,
+            nn.Conv2d(
+                in_channels=bn_size * growth_rate,
+                out_channels=growth_rate,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                bias=False,
+            ),
+        ]
+        if dropout_rate:
+            self.dense_layer.append(nn.Dropout(p=dropout_rate))
+
+        self.dense_layer = nn.Sequential(*self.dense_layer)
+
+    def forward(self, x: Union[Tensor, List[Tensor]]) -> Tensor:
+        if isinstance(x, list):
+            x = torch.cat(x, 1)
+        return self.dense_layer(x)
+
+
+class _DenseBlock(nn.Module):
+    def __init__(
+        self,
+        num_layers: int,
+        in_channels: int,
+        bn_size: int,
+        growth_rate: int,
+        dropout_rate: float,
+        activation: str = "relu",
+    ) -> None:
+        super().__init__()
+        self.dense_block = self._build_dense_blocks(
+            num_layers, in_channels, bn_size, growth_rate, dropout_rate, activation,
+        )
+
+    def _build_dense_blocks(
+        self,
+        num_layers: int,
+        in_channels: int,
+        bn_size: int,
+        growth_rate: int,
+        dropout_rate: float,
+        activation: str = "relu",
+    ) -> nn.ModuleList:
+        dense_block = []
+        for i in range(num_layers):
+            dense_block.append(
+                _DenseLayer(
+                    in_channels=in_channels + i * growth_rate,
+                    growth_rate=growth_rate,
+                    bn_size=bn_size,
+                    dropout_rate=dropout_rate,
+                    activation=activation,
+                )
+            )
+        return nn.ModuleList(dense_block)
+
+    def forward(self, x: Tensor) -> Tensor:
+        feature_maps = [x]
+        for layer in self.dense_block:
+            x = layer(feature_maps)
+            feature_maps.append(x)
+        return torch.cat(feature_maps, 1)
+
+
+class _Transition(nn.Module):
+    def __init__(
+        self, in_channels: int, out_channels: int, activation: str = "relu",
+    ) -> None:
+        super().__init__()
+        activation_fn = activation_function(activation)
+        self.transition = nn.Sequential(
+            nn.BatchNorm2d(in_channels),
+            activation_fn,
+            nn.Conv2d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=1,
+                stride=1,
+                bias=False,
+            ),
+            nn.AvgPool2d(kernel_size=2, stride=2),
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.transition(x)
+
+
+class DenseNet(nn.Module):
+    """Implementation of Densenet, a network archtecture that concats previous layers for maximum infomation flow."""
+
+    def __init__(
+        self,
+        growth_rate: int = 32,
+        block_config: Tuple[int, ...] = (6, 12, 24, 16),
+        in_channels: int = 1,
+        base_channels: int = 64,
+        num_classes: int = 80,
+        bn_size: int = 4,
+        dropout_rate: float = 0,
+        classifier: bool = True,
+        activation: str = "relu",
+    ) -> None:
+        super().__init__()
+        self.densenet = self._configure_densenet(
+            in_channels,
+            base_channels,
+            num_classes,
+            growth_rate,
+            block_config,
+            bn_size,
+            dropout_rate,
+            classifier,
+            activation,
+        )
+
+    def _configure_densenet(
+        self,
+        in_channels: int,
+        base_channels: int,
+        num_classes: int,
+        growth_rate: int,
+        block_config: Tuple[int, ...],
+        bn_size: int,
+        dropout_rate: float,
+        classifier: bool,
+        activation: str,
+    ) -> nn.Sequential:
+        activation_fn = activation_function(activation)
+        densenet = [
+            nn.Conv2d(
+                in_channels=in_channels,
+                out_channels=base_channels,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                bias=False,
+            ),
+            nn.BatchNorm2d(base_channels),
+            activation_fn,
+        ]
+
+        num_features = base_channels
+
+        for i, num_layers in enumerate(block_config):
+            densenet.append(
+                _DenseBlock(
+                    num_layers=num_layers,
+                    in_channels=num_features,
+                    bn_size=bn_size,
+                    growth_rate=growth_rate,
+                    dropout_rate=dropout_rate,
+                    activation=activation,
+                )
+            )
+            num_features = num_features + num_layers * growth_rate
+            if i != len(block_config) - 1:
+                densenet.append(
+                    _Transition(
+                        in_channels=num_features,
+                        out_channels=num_features // 2,
+                        activation=activation,
+                    )
+                )
+                num_features = num_features // 2
+
+        densenet.append(activation_fn)
+
+        if classifier:
+            densenet.append(nn.AdaptiveAvgPool2d((1, 1)))
+            densenet.append(Rearrange("b c h w -> b (c h w)"))
+            densenet.append(
+                nn.Linear(in_features=num_features, out_features=num_classes)
+            )
+
+        return nn.Sequential(*densenet)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward pass of Densenet."""
+        # If the batch dimension is missing, it will be added.
+        if len(x.shape) < 4:
+            x = x[(None,) * (4 - len(x.shape))]
+        return self.densenet(x)
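
A quick shape check of the classifier configuration, assuming the package
is importable (the sizes are illustrative, not tuned values):

    import torch
    from text_recognizer.networks import DenseNet

    net = DenseNet(growth_rate=12, block_config=(4, 4), num_classes=80)
    x = torch.rand(2, 1, 28, 28)  # (batch, channels, height, width)
    logits = net(x)               # logits.shape == (2, 80)
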
diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py
index 53c575e..527e1a0 100644
--- a/src/text_recognizer/networks/lenet.py
+++ b/src/text_recognizer/networks/lenet.py
@@ -5,7 +5,7 @@ from einops.layers.torch import Rearrange
 import torch
 from torch import nn
 
-from text_recognizer.networks.misc import activation_function
+from text_recognizer.networks.util import activation_function
 
 
 class LeNet(nn.Module):
@@ -63,6 +63,6 @@ class LeNet(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """The feedforward pass."""
         # If batch dimenstion is missing, it needs to be added.
-        if len(x.shape) == 3:
-            x = x.unsqueeze(0)
+        if len(x.shape) < 4:
+            x = x[(None,) * (4 - len(x.shape))]
         return self.layers(x)
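
The new batch-dimension fix also handles 2D inputs: indexing with a tuple
of None values prepends that many axes. For example:

    import torch

    x = torch.rand(28, 28)               # (height, width)
    x = x[(None,) * (4 - len(x.shape))]  # torch.Size([1, 1, 28, 28])
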
diff --git a/src/text_recognizer/networks/line_lstm_ctc.py b/src/text_recognizer/networks/line_lstm_ctc.py
deleted file mode 100644
index 9009f94..0000000
--- a/src/text_recognizer/networks/line_lstm_ctc.py
+++ /dev/null
@@ -1,120 +0,0 @@
-"""LSTM with CTC for handwritten text recognition within a line."""
-import importlib
-from pathlib import Path
-from typing import Callable, Dict, List, Optional, Tuple, Type, Union
-
-from einops import rearrange, reduce
-from einops.layers.torch import Rearrange, Reduce
-from loguru import logger
-import torch
-from torch import nn
-from torch import Tensor
-
-
-class LineRecurrentNetwork(nn.Module):
-    """Network that takes a image of a text line and predicts tokens that are in the image."""
-
-    def __init__(
-        self,
-        backbone: str,
-        backbone_args: Dict = None,
-        flatten: bool = True,
-        input_size: int = 128,
-        hidden_size: int = 128,
-        bidirectional: bool = False,
-        num_layers: int = 1,
-        num_classes: int = 80,
-        patch_size: Tuple[int, int] = (28, 28),
-        stride: Tuple[int, int] = (1, 14),
-    ) -> None:
-        super().__init__()
-        self.backbone_args = backbone_args or {}
-        self.patch_size = patch_size
-        self.stride = stride
-        self.sliding_window = self._configure_sliding_window()
-        self.input_size = input_size
-        self.hidden_size = hidden_size
-        self.backbone = self._configure_backbone(backbone)
-        self.bidirectional = bidirectional
-        self.flatten = flatten
-
-        if self.flatten:
-            self.fc = nn.Linear(
-                in_features=self.input_size, out_features=self.hidden_size
-            )
-
-        self.rnn = nn.LSTM(
-            input_size=self.hidden_size,
-            hidden_size=self.hidden_size,
-            bidirectional=bidirectional,
-            num_layers=num_layers,
-        )
-
-        decoder_size = self.hidden_size * 2 if self.bidirectional else self.hidden_size
-
-        self.decoder = nn.Sequential(
-            nn.Linear(in_features=decoder_size, out_features=num_classes),
-            nn.LogSoftmax(dim=2),
-        )
-
-    def _configure_backbone(self, backbone: str) -> Type[nn.Module]:
-        network_module = importlib.import_module("text_recognizer.networks")
-        backbone_ = getattr(network_module, backbone)
-
-        if "pretrained" in self.backbone_args:
-            logger.info("Loading pretrained backbone.")
-            checkpoint_file = Path(__file__).resolve().parents[
-                2
-            ] / self.backbone_args.pop("pretrained")
-
-            # Loading state directory.
-            state_dict = torch.load(checkpoint_file)
-            network_args = state_dict["network_args"]
-            weights = state_dict["model_state"]
-
-            # Initializes the network with trained weights.
-            backbone = backbone_(**network_args)
-            backbone.load_state_dict(weights)
-            if "freeze" in self.backbone_args and self.backbone_args["freeze"] is True:
-                for params in backbone.parameters():
-                    params.requires_grad = False
-
-            return backbone
-        else:
-            return backbone_(**self.backbone_args)
-
-    def _configure_sliding_window(self) -> nn.Sequential:
-        return nn.Sequential(
-            nn.Unfold(kernel_size=self.patch_size, stride=self.stride),
-            Rearrange(
-                "b (c h w) t -> b t c h w",
-                h=self.patch_size[0],
-                w=self.patch_size[1],
-                c=1,
-            ),
-        )
-
-    def forward(self, x: Tensor) -> Tensor:
-        """Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM."""
-        if len(x.shape) == 3:
-            x = x.unsqueeze(0)
-        x = self.sliding_window(x)
-
-        # Rearrange from a sequence of patches for feedforward network.
-        b, t = x.shape[:2]
-        x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
-        x = self.backbone(x)
-
-        # Avgerage pooling.
-        x = (
-            self.fc(reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t))
-            if self.flatten
-            else rearrange(x, "(b t) h -> t b h", b=b, t=t)
-        )
-
-        # Sequence predictions.
-        x, _ = self.rnn(x)
-
-        # Sequence to classifcation layer.
-        x = self.decoder(x)
-        return x
diff --git a/src/text_recognizer/networks/loss.py b/src/text_recognizer/networks/loss.py
new file mode 100644
index 0000000..cf9fa0d
--- /dev/null
+++ b/src/text_recognizer/networks/loss.py
@@ -0,0 +1,69 @@
+"""Implementations of custom loss functions."""
+from typing import Optional
+
+from pytorch_metric_learning import distances, losses, miners, reducers
+import torch
+from torch import nn
+from torch import Tensor
+
+__all__ = ["EmbeddingLoss", "LabelSmoothingCrossEntropy"]
+
+
+class EmbeddingLoss:
+    """Metric loss for training encoders to produce information-rich latent embeddings."""
+
+    def __init__(self, margin: float = 0.2, type_of_triplets: str = "semihard") -> None:
+        self.distance = distances.CosineSimilarity()
+        self.reducer = reducers.ThresholdReducer(low=0)
+        self.loss_fn = losses.TripletMarginLoss(
+            margin=margin, distance=self.distance, reducer=self.reducer
+        )
+        self.miner = miners.MultiSimilarityMiner(epsilon=margin, distance=self.distance)
+
+    def __call__(self, embeddings: Tensor, labels: Tensor) -> Tensor:
+        """Computes the metric loss for the embeddings based on their labels.
+
+        Args:
+            embeddings (Tensor): The latent vectors encoded by the network.
+            labels (Tensor): Labels of the embeddings.
+
+        Returns:
+            Tensor: The metric loss for the embeddings.
+
+        """
+        hard_pairs = self.miner(embeddings, labels)
+        loss = self.loss_fn(embeddings, labels, hard_pairs)
+        return loss
+
+
+class LabelSmoothingCrossEntropy(nn.Module):
+    """Label smoothing loss function."""
+
+    def __init__(
+        self,
+        classes: int,
+        smoothing: float = 0.0,
+        ignore_index: Optional[int] = None,
+        dim: int = -1,
+    ) -> None:
+        super().__init__()
+        self.confidence = 1.0 - smoothing
+        self.smoothing = smoothing
+        self.ignore_index = ignore_index
+        self.cls = classes
+        self.dim = dim
+
+    def forward(self, pred: Tensor, target: Tensor) -> Tensor:
+        """Calculates the loss."""
+        pred = pred.log_softmax(dim=self.dim)
+        with torch.no_grad():
+            # Spread the smoothing mass uniformly, then place the confidence on the target class.
+            true_dist = torch.zeros_like(pred)
+            true_dist.fill_(self.smoothing / (self.cls - 1))
+            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
+            if self.ignore_index is not None:
+                true_dist[:, self.ignore_index] = 0
+                mask = torch.nonzero(target == self.ignore_index, as_tuple=False)
+                if mask.dim() > 0:
+                    true_dist.index_fill_(0, mask.squeeze(), 0.0)
+        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
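
A minimal usage sketch of LabelSmoothingCrossEntropy; with smoothing=0.0
it reduces to ordinary cross entropy (sizes are illustrative):

    import torch
    from text_recognizer.networks.loss import LabelSmoothingCrossEntropy

    criterion = LabelSmoothingCrossEntropy(classes=80, smoothing=0.1)
    pred = torch.randn(4, 80)            # raw logits, (batch, classes)
    target = torch.randint(0, 80, (4,))  # class indices
    loss = criterion(pred, target)       # scalar
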
diff --git a/src/text_recognizer/networks/losses.py b/src/text_recognizer/networks/losses.py
deleted file mode 100644
index 73e0641..0000000
--- a/src/text_recognizer/networks/losses.py
+++ /dev/null
@@ -1,31 +0,0 @@
-"""Implementations of custom loss functions."""
-from pytorch_metric_learning import distances, losses, miners, reducers
-from torch import nn
-from torch import Tensor
-
-
-class EmbeddingLoss:
-    """Metric loss for training encoders to produce information-rich latent embeddings."""
-
-    def __init__(self, margin: float = 0.2, type_of_triplets: str = "semihard") -> None:
-        self.distance = distances.CosineSimilarity()
-        self.reducer = reducers.ThresholdReducer(low=0)
-        self.loss_fn = losses.TripletMarginLoss(
-            margin=margin, distance=self.distance, reducer=self.reducer
-        )
-        self.miner = miners.MultiSimilarityMiner(epsilon=margin, distance=self.distance)
-
-    def __call__(self, embeddings: Tensor, labels: Tensor) -> Tensor:
-        """Computes the metric loss for the embeddings based on their labels.
-
-        Args:
-            embeddings (Tensor): The laten vectors encoded by the network.
-            labels (Tensor): Labels of the embeddings.
-
-        Returns:
-            Tensor: The metric loss for the embeddings.
-
-        """
-        hard_pairs = self.miner(embeddings, labels)
-        loss = self.loss_fn(embeddings, labels, hard_pairs)
-        return loss
diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/misc.py
deleted file mode 100644
index 1f853e9..0000000
--- a/src/text_recognizer/networks/misc.py
+++ /dev/null
@@ -1,45 +0,0 @@
-"""Miscellaneous neural network functionality."""
-from typing import Tuple, Type
-
-from einops import rearrange
-import torch
-from torch import nn
-
-
-def sliding_window(
-    images: torch.Tensor, patch_size: Tuple[int, int], stride: Tuple[int, int]
-) -> torch.Tensor:
-    """Creates patches of an image.
-
-    Args:
-        images (torch.Tensor): A Torch tensor of a 4D image(s), i.e. (batch, channel, height, width).
-        patch_size (Tuple[int, int]): The size of the patches to generate, e.g. 28x28 for EMNIST.
-        stride (Tuple[int, int]): The stride of the sliding window.
-
-    Returns:
-        torch.Tensor: A tensor with the shape (batch, patches, height, width).
-
-    """
-    unfold = nn.Unfold(kernel_size=patch_size, stride=stride)
-    # Preform the slidning window, unsqueeze as the channel dimesion is lost.
-    c = images.shape[1]
-    patches = unfold(images)
-    patches = rearrange(
-        patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1]
-    )
-    return patches
-
-
-def activation_function(activation: str) -> Type[nn.Module]:
-    """Returns the callable activation function."""
-    activation_fns = nn.ModuleDict(
-        [
-            ["elu", nn.ELU(inplace=True)],
-            ["gelu", nn.GELU()],
-            ["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)],
-            ["none", nn.Identity()],
-            ["relu", nn.ReLU(inplace=True)],
-            ["selu", nn.SELU(inplace=True)],
-        ]
-    )
-    return activation_fns[activation.lower()]
diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py
index d66af28..1101912 100644
--- a/src/text_recognizer/networks/mlp.py
+++ b/src/text_recognizer/networks/mlp.py
@@ -5,7 +5,7 @@ from einops.layers.torch import Rearrange
 import torch
 from torch import nn
 
-from text_recognizer.networks.misc import activation_function
+from text_recognizer.networks.util import activation_function
 
 
 class MLP(nn.Module):
@@ -63,8 +63,8 @@ class MLP(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """The feedforward pass."""
         # If batch dimenstion is missing, it needs to be added.
-        if len(x.shape) == 3:
-            x = x.unsqueeze(0)
+        if len(x.shape) < 4:
+            x = x[(None,) * (4 - len(x.shape))]
         return self.layers(x)
 
     @property
diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py
index 046600d..6405192 100644
--- a/src/text_recognizer/networks/residual_network.py
+++ b/src/text_recognizer/networks/residual_network.py
@@ -7,8 +7,8 @@ import torch
 from torch import nn
 from torch import Tensor
 
-from text_recognizer.networks.misc import activation_function
 from text_recognizer.networks.stn import SpatialTransformerNetwork
+from text_recognizer.networks.util import activation_function
 
 
 class Conv2dAuto(nn.Conv2d):
@@ -225,8 +225,8 @@ class ResidualNetworkEncoder(nn.Module):
                 in_channels=in_channels,
                 out_channels=self.block_sizes[0],
                 kernel_size=3,
-                stride=2,
-                padding=3,
+                stride=1,
+                padding=1,
                 bias=False,
             ),
             nn.BatchNorm2d(self.block_sizes[0]),
diff --git a/src/text_recognizer/networks/sparse_mlp.py b/src/text_recognizer/networks/sparse_mlp.py
new file mode 100644
index 0000000..53cf166
--- /dev/null
+++ b/src/text_recognizer/networks/sparse_mlp.py
@@ -0,0 +1,78 @@
+"""Defines the Sparse MLP network."""
+from typing import List, Union
+import warnings
+
+from einops.layers.torch import Rearrange
+from pytorch_block_sparse import BlockSparseLinear
+import torch
+from torch import nn
+
+from text_recognizer.networks.util import activation_function
+
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+
+class SparseMLP(nn.Module):
+    """Sparse multi layered perceptron network."""
+
+    def __init__(
+        self,
+        input_size: int = 784,
+        num_classes: int = 10,
+        hidden_size: Union[int, List] = 128,
+        num_layers: int = 3,
+        density: float = 0.1,
+        activation_fn: str = "relu",
+    ) -> None:
+        """Initialization of the MLP network.
+
+        Args:
+            input_size (int): The input shape of the network. Defaults to 784.
+            num_classes (int): Number of classes in the dataset. Defaults to 10.
+            hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128.
+            num_layers (int): The number of hidden layers. Defaults to 3.
+            density (float): The fraction of non-zero weight blocks in each sparse layer. Defaults to 0.1.
+            activation_fn (str): Name of the activation function in the hidden layers. Defaults to
+                relu.
+
+        """
+        super().__init__()
+
+        activation_fn = activation_function(activation_fn)
+
+        if isinstance(hidden_size, int):
+            hidden_size = [hidden_size] * num_layers
+
+        self.layers = [
+            Rearrange("b c h w -> b (c h w)"),
+            nn.Linear(in_features=input_size, out_features=hidden_size[0]),
+            activation_fn,
+        ]
+
+        for i in range(num_layers - 1):
+            self.layers += [
+                BlockSparseLinear(
+                    in_features=hidden_size[i],
+                    out_features=hidden_size[i + 1],
+                    density=density,
+                ),
+                activation_fn,
+            ]
+
+        self.layers.append(
+            nn.Linear(in_features=hidden_size[-1], out_features=num_classes)
+        )
+
+        self.layers = nn.Sequential(*self.layers)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """The feedforward pass."""
+        # If the batch dimension is missing, it needs to be added.
+        if len(x.shape) < 4:
+            x = x[(None,) * (4 - len(x.shape))]
+        return self.layers(x)
+
+    @property
+    def __name__(self) -> str:
+        """Returns the name of the network."""
+        return "mlp"
diff --git a/src/text_recognizer/networks/transformer.py b/src/text_recognizer/networks/transformer.py
deleted file mode 100644
index c091ba0..0000000
--- a/src/text_recognizer/networks/transformer.py
+++ /dev/null
@@ -1,5 +0,0 @@
-"""TBC."""
-from typing import Dict
-
-import torch
-from torch import Tensor
diff --git a/src/text_recognizer/networks/transformer/__init__.py b/src/text_recognizer/networks/transformer/__init__.py
new file mode 100644
index 0000000..020a917
--- /dev/null
+++ b/src/text_recognizer/networks/transformer/__init__.py
@@ -0,0 +1,3 @@
+"""Transformer modules."""
+from .positional_encoding import PositionalEncoding
+from .transformer import Decoder, Encoder, Transformer
diff --git a/src/text_recognizer/networks/transformer/attention.py b/src/text_recognizer/networks/transformer/attention.py
new file mode 100644
index 0000000..cce1ecc
--- /dev/null
+++ b/src/text_recognizer/networks/transformer/attention.py
@@ -0,0 +1,93 @@
+"""Implementes the attention module for the transformer."""
+from typing import Optional, Tuple
+
+from einops import rearrange
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class MultiHeadAttention(nn.Module):
+    """Implementation of multihead attention."""
+
+    def __init__(
+        self, hidden_dim: int, num_heads: int = 8, dropout_rate: float = 0.0
+    ) -> None:
+        super().__init__()
+        self.hidden_dim = hidden_dim
+        self.num_heads = num_heads
+        self.fc_q = nn.Linear(
+            in_features=hidden_dim, out_features=hidden_dim, bias=False
+        )
+        self.fc_k = nn.Linear(
+            in_features=hidden_dim, out_features=hidden_dim, bias=False
+        )
+        self.fc_v = nn.Linear(
+            in_features=hidden_dim, out_features=hidden_dim, bias=False
+        )
+        self.fc_out = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
+
+        self._init_weights()
+
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+    def _init_weights(self) -> None:
+        nn.init.normal_(
+            self.fc_q.weight,
+            mean=0,
+            std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
+        )
+        nn.init.normal_(
+            self.fc_k.weight,
+            mean=0,
+            std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
+        )
+        nn.init.normal_(
+            self.fc_v.weight,
+            mean=0,
+            std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
+        )
+        nn.init.xavier_normal_(self.fc_out.weight)
+
+    def scaled_dot_product_attention(
+        self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
+    ) -> Tuple[Tensor, Tensor]:
+        """Calculates the scaled dot product attention."""
+
+        # Compute the energy.
+        energy = torch.einsum("bhlk,bhtk->bhlt", [query, key]) / np.sqrt(
+            query.shape[-1]
+        )
+
+        # If we have a mask for padding some inputs.
+        if mask is not None:
+            energy = energy.masked_fill(mask == 0, -np.inf)
+
+        # Compute the attention from the energy.
+        attention = torch.softmax(energy, dim=3)
+
+        out = torch.einsum("bhlt,bhtv->bhlv", [attention, value])
+        out = rearrange(out, "b head l v -> b l (head v)")
+        return out, attention
+
+    def forward(
+        self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
+    ) -> Tuple[Tensor, Tensor]:
+        """Forward pass for computing the multihead attention."""
+        # Get the query, key, and value tensor.
+        query = rearrange(
+            self.fc_q(query), "b l (head k) -> b head l k", head=self.num_heads
+        )
+        key = rearrange(
+            self.fc_k(key), "b t (head k) -> b head t k", head=self.num_heads
+        )
+        value = rearrange(
+            self.fc_v(value), "b t (head v) -> b head t v", head=self.num_heads
+        )
+
+        out, attention = self.scaled_dot_product_attention(query, key, value, mask)
+
+        out = self.fc_out(out)
+        out = self.dropout(out)
+        return out, attention
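
The two einsum calls above, spelled out with illustrative shapes: energy is
(batch, heads, query length, key length), and each output row is an
attention-weighted sum over the value vectors:

    import numpy as np
    import torch

    b, heads, l, t, k = 2, 8, 5, 7, 16
    query = torch.rand(b, heads, l, k)
    key = torch.rand(b, heads, t, k)
    value = torch.rand(b, heads, t, k)

    energy = torch.einsum("bhlk,bhtk->bhlt", [query, key]) / np.sqrt(k)
    attention = torch.softmax(energy, dim=3)                   # rows sum to 1 over t
    out = torch.einsum("bhlt,bhtv->bhlv", [attention, value])  # (2, 8, 5, 16)
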
diff --git a/src/text_recognizer/networks/transformer/positional_encoding.py b/src/text_recognizer/networks/transformer/positional_encoding.py
new file mode 100644
index 0000000..1ba5537
--- /dev/null
+++ b/src/text_recognizer/networks/transformer/positional_encoding.py
@@ -0,0 +1,32 @@
+"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence."""
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class PositionalEncoding(nn.Module):
+    """Encodes a sense of distance or time for transformer networks."""
+
+    def __init__(
+        self, hidden_dim: int, dropout_rate: float, max_len: int = 1000
+    ) -> None:
+        super().__init__()
+        self.dropout = nn.Dropout(p=dropout_rate)
+        self.max_len = max_len
+
+        pe = torch.zeros(max_len, hidden_dim)
+        position = torch.arange(0, max_len).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim)
+        )
+
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0)
+        self.register_buffer("pe", pe)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Encodes the tensor with a postional embedding."""
+        x = x + self.pe[:, : x.shape[1]]
+        return self.dropout(x)
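
The buffer pe has shape (1, max_len, hidden_dim), so slicing it to the
input length broadcasts one encoding table over the whole batch. A small
sketch (dropout disabled to make the output deterministic):

    import torch
    from text_recognizer.networks.transformer import PositionalEncoding

    encoder = PositionalEncoding(hidden_dim=64, dropout_rate=0.0)
    x = torch.zeros(2, 10, 64)  # (batch, sequence, hidden_dim)
    out = encoder(x)            # row i holds the encoding of position i
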
diff --git a/src/text_recognizer/networks/transformer/transformer.py b/src/text_recognizer/networks/transformer/transformer.py
new file mode 100644
index 0000000..c6e943e
--- /dev/null
+++ b/src/text_recognizer/networks/transformer/transformer.py
@@ -0,0 +1,242 @@
+"""Transfomer module."""
+import copy
+from typing import Optional
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer.attention import MultiHeadAttention
+from text_recognizer.networks.util import activation_function
+
+
+def _get_clones(module: nn.Module, num_layers: int) -> nn.ModuleList:
+    return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)])
+
+
+class _IntraLayerConnection(nn.Module):
+    """Preforms the residual connection inside the transfomer blocks and applies layernorm."""
+
+    def __init__(self, dropout_rate: float, hidden_dim: int) -> None:
+        super().__init__()
+        self.norm = nn.LayerNorm(normalized_shape=hidden_dim)
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+    def forward(self, src: Tensor, residual: Tensor) -> Tensor:
+        return self.norm(self.dropout(src) + residual)
+
+
+class _ConvolutionalLayer(nn.Module):
+    def __init__(
+        self,
+        hidden_dim: int,
+        expansion_dim: int,
+        dropout_rate: float,
+        activation: str = "relu",
+    ) -> None:
+        super().__init__()
+        self.layer = nn.Sequential(
+            nn.Linear(in_features=hidden_dim, out_features=expansion_dim),
+            activation_function(activation),
+            nn.Dropout(p=dropout_rate),
+            nn.Linear(in_features=expansion_dim, out_features=hidden_dim),
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.layer(x)
+
+
+class EncoderLayer(nn.Module):
+    """Transfomer encoding layer."""
+
+    def __init__(
+        self,
+        hidden_dim: int,
+        num_heads: int,
+        expansion_dim: int,
+        dropout_rate: float,
+        activation: str = "relu",
+    ) -> None:
+        super().__init__()
+        self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
+        self.cnn = _ConvolutionalLayer(
+            hidden_dim, expansion_dim, dropout_rate, activation
+        )
+        self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
+        self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
+
+    def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor:
+        """Forward pass through the encoder."""
+        # First block.
+        # Multi head attention.
+        out, _ = self.self_attention(src, src, src, mask)
+
+        # Add & norm.
+        out = self.block1(out, src)
+
+        # Second block.
+        # Apply the position-wise feedforward layer (a 1x1-convolution equivalent).
+        cnn_out = self.cnn(out)
+
+        # Add & norm.
+        out = self.block2(cnn_out, out)
+
+        return out
+
+
+class Encoder(nn.Module):
+    """Transfomer encoder module."""
+
+    def __init__(
+        self,
+        num_layers: int,
+        encoder_layer: nn.Module,
+        norm: Optional[nn.Module] = None,
+    ) -> None:
+        super().__init__()
+        self.layers = _get_clones(encoder_layer, num_layers)
+        self.norm = norm
+
+    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
+        """Forward pass through all encoder layers."""
+        for layer in self.layers:
+            src = layer(src, src_mask)
+
+        if self.norm is not None:
+            src = self.norm(src)
+
+        return src
+
+
+class DecoderLayer(nn.Module):
+    """Transfomer decoder layer."""
+
+    def __init__(
+        self,
+        hidden_dim: int,
+        num_heads: int,
+        expansion_dim: int,
+        dropout_rate: float = 0.0,
+        activation: str = "relu",
+    ) -> None:
+        super().__init__()
+        self.hidden_dim = hidden_dim
+        self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
+        self.multihead_attention = MultiHeadAttention(
+            hidden_dim, num_heads, dropout_rate
+        )
+        self.cnn = _ConvolutionalLayer(
+            hidden_dim, expansion_dim, dropout_rate, activation
+        )
+        self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
+        self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
+        self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim)
+
+    def forward(
+        self,
+        trg: Tensor,
+        memory: Tensor,
+        trg_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        """Forward pass of the layer."""
+        out, _ = self.self_attention(trg, trg, trg, trg_mask)
+        trg = self.block1(out, trg)
+
+        out, _ = self.multihead_attention(trg, memory, memory, memory_mask)
+        trg = self.block2(out, trg)
+
+        out = self.cnn(trg)
+        out = self.block3(out, trg)
+
+        return out
+
+
+class Decoder(nn.Module):
+    """Transfomer decoder module."""
+
+    def __init__(
+        self,
+        decoder_layer: nn.Module,
+        num_layers: int,
+        norm: Optional[nn.Module] = None,
+    ) -> None:
+        super().__init__()
+        self.layers = _get_clones(decoder_layer, num_layers)
+        self.num_layers = num_layers
+        self.norm = norm
+
+    def forward(
+        self,
+        trg: Tensor,
+        memory: Tensor,
+        trg_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        """Forward pass through the decoder."""
+        for layer in self.layers:
+            trg = layer(trg, memory, trg_mask, memory_mask)
+
+        if self.norm is not None:
+            trg = self.norm(trg)
+
+        return trg
+
+
+class Transformer(nn.Module):
+    """Transformer network."""
+
+    def __init__(
+        self,
+        num_encoder_layers: int,
+        num_decoder_layers: int,
+        hidden_dim: int,
+        num_heads: int,
+        expansion_dim: int,
+        dropout_rate: float,
+        activation: str = "relu",
+    ) -> None:
+        super().__init__()
+
+        # Configure encoder.
+        encoder_norm = nn.LayerNorm(hidden_dim)
+        encoder_layer = EncoderLayer(
+            hidden_dim, num_heads, expansion_dim, dropout_rate, activation
+        )
+        self.encoder = Encoder(num_encoder_layers, encoder_layer, encoder_norm)
+
+        # Configure decoder.
+        decoder_norm = nn.LayerNorm(hidden_dim)
+        decoder_layer = DecoderLayer(
+            hidden_dim, num_heads, expansion_dim, dropout_rate, activation
+        )
+        self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm)
+
+        self._reset_parameters()
+
+    def _reset_parameters(self) -> None:
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+    def forward(
+        self,
+        src: Tensor,
+        trg: Tensor,
+        src_mask: Optional[Tensor] = None,
+        trg_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        """Forward pass through the transformer."""
+        if src.shape[0] != trg.shape[0]:
+            raise RuntimeError("The batch size of the src and trg must be the same.")
+
+        if src.shape[2] != trg.shape[2]:
+            raise RuntimeError(
+                "The number of features for the src and trg must be the same."
+            )
+
+        memory = self.encoder(src, src_mask)
+        output = self.decoder(trg, memory, trg_mask, memory_mask)
+        return output
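
An end-to-end shape check of the Transformer on random features (the
hyperparameters are illustrative):

    import torch
    from text_recognizer.networks.transformer import Transformer

    net = Transformer(
        num_encoder_layers=2,
        num_decoder_layers=2,
        hidden_dim=64,
        num_heads=8,
        expansion_dim=256,
        dropout_rate=0.1,
    )
    src = torch.rand(2, 12, 64)  # (batch, source length, hidden_dim)
    trg = torch.rand(2, 9, 64)   # (batch, target length, hidden_dim)
    out = net(src, trg)          # (2, 9, 64)
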
diff --git a/src/text_recognizer/networks/util.py b/src/text_recognizer/networks/util.py
new file mode 100644
index 0000000..b31e640
--- /dev/null
+++ b/src/text_recognizer/networks/util.py
@@ -0,0 +1,83 @@
+"""Miscellaneous neural network functionality."""
+import importlib
+from pathlib import Path
+from typing import Dict, Tuple
+
+from einops import rearrange
+from loguru import logger
+import torch
+from torch import nn
+
+
+def sliding_window(
+    images: torch.Tensor, patch_size: Tuple[int, int], stride: Tuple[int, int]
+) -> torch.Tensor:
+    """Creates patches of an image.
+
+    Args:
+        images (torch.Tensor): A Torch tensor of a 4D image(s), i.e. (batch, channel, height, width).
+        patch_size (Tuple[int, int]): The size of the patches to generate, e.g. 28x28 for EMNIST.
+        stride (Tuple[int, int]): The stride of the sliding window.
+
+    Returns:
+        torch.Tensor: A tensor with the shape (batch, patches, height, width).
+
+    """
+    unfold = nn.Unfold(kernel_size=patch_size, stride=stride)
+    # Perform the sliding window and restore the channel dimension.
+    c = images.shape[1]
+    patches = unfold(images)
+    patches = rearrange(
+        patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1],
+    )
+    return patches
+
+
+def activation_function(activation: str) -> nn.Module:
+    """Returns the callable activation function."""
+    activation_fns = nn.ModuleDict(
+        [
+            ["elu", nn.ELU(inplace=True)],
+            ["gelu", nn.GELU()],
+            ["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)],
+            ["none", nn.Identity()],
+            ["relu", nn.ReLU(inplace=True)],
+            ["selu", nn.SELU(inplace=True)],
+        ]
+    )
+    return activation_fns[activation.lower()]
+
+
+def configure_backbone(backbone: str, backbone_args: Dict) -> nn.Module:
+    """Loads a backbone network."""
+    network_module = importlib.import_module("text_recognizer.networks")
+    backbone_ = getattr(network_module, backbone)
+
+    if "pretrained" in backbone_args:
+        logger.info("Loading pretrained backbone.")
+        checkpoint_file = Path(__file__).resolve().parents[2] / backbone_args.pop(
+            "pretrained"
+        )
+
+        # Load the state dict.
+        state_dict = torch.load(checkpoint_file)
+        network_args = state_dict["network_args"]
+        weights = state_dict["model_state"]
+
+        # Initializes the network with trained weights.
+        backbone = backbone_(**network_args)
+        backbone.load_state_dict(weights)
+        if "freeze" in backbone_args and backbone_args["freeze"] is True:
+            for params in backbone.parameters():
+                params.requires_grad = False
+
+    else:
+        # Initialize the network with random weights.
+        backbone = backbone_(**backbone_args)
+
+    if "remove_layers" in backbone_args and backbone_args["remove_layers"] is not None:
+        backbone = nn.Sequential(
+            *list(backbone.children())[:][: -backbone_args["remove_layers"]]
+        )
+
+    return backbone
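
For example, sliding_window on a 28x952 line image with a (28, 28) patch
and a (1, 14) stride yields (952 - 28) / 14 + 1 = 67 patches:

    import torch
    from text_recognizer.networks.util import sliding_window

    images = torch.rand(1, 1, 28, 952)  # (batch, channels, height, width)
    patches = sliding_window(images, patch_size=(28, 28), stride=(1, 14))
    # patches.shape == (1, 67, 1, 28, 28)
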
diff --git a/src/text_recognizer/networks/vision_transformer.py b/src/text_recognizer/networks/vision_transformer.py
new file mode 100644
index 0000000..f227954
--- /dev/null
+++ b/src/text_recognizer/networks/vision_transformer.py
@@ -0,0 +1,159 @@
+"""VisionTransformer module.
+
+Splits each image into patches and feeds them to a transformer.
+
+"""
+
+from typing import Dict, Optional, Tuple, Type
+
+from einops import rearrange, reduce
+from einops.layers.torch import Rearrange
+from loguru import logger
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer import PositionalEncoding, Transformer
+from text_recognizer.networks.util import configure_backbone
+
+
+class VisionTransformer(nn.Module):
+    """Linear projection+Transfomer for image to sequence prediction, sort of based on the ideas from ViT."""
+
+    def __init__(
+        self,
+        num_encoder_layers: int,
+        num_decoder_layers: int,
+        hidden_dim: int,
+        vocab_size: int,
+        num_heads: int,
+        max_len: int,
+        expansion_dim: int,
+        dropout_rate: float,
+        trg_pad_index: int,
+        mlp_dim: Optional[int] = None,
+        patch_size: Tuple[int, int] = (28, 28),
+        stride: Tuple[int, int] = (1, 14),
+        activation: str = "gelu",
+        backbone: Optional[str] = None,
+        backbone_args: Optional[Dict] = None,
+    ) -> None:
+        super().__init__()
+
+        self.patch_size = patch_size
+        self.stride = stride
+        self.trg_pad_index = trg_pad_index
+        self.sliding_window = self._configure_sliding_window()
+        self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
+        self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate, max_len)
+        self.mlp_dim = mlp_dim
+
+        self.use_backbone = False
+        if backbone is None:
+            self.linear_projection = nn.Linear(
+                self.patch_size[0] * self.patch_size[1], hidden_dim
+            )
+        else:
+            self.backbone = configure_backbone(backbone, backbone_args)
+            if mlp_dim:
+                self.mlp = nn.Linear(mlp_dim, hidden_dim)
+            self.use_backbone = True
+
+        self.transformer = Transformer(
+            num_encoder_layers,
+            num_decoder_layers,
+            hidden_dim,
+            num_heads,
+            expansion_dim,
+            dropout_rate,
+            activation,
+        )
+
+        self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
+
+    def _configure_sliding_window(self) -> nn.Sequential:
+        return nn.Sequential(
+            nn.Unfold(kernel_size=self.patch_size, stride=self.stride),
+            Rearrange(
+                "b (c h w) t -> b t c h w",
+                h=self.patch_size[0],
+                w=self.patch_size[1],
+                c=1,
+            ),
+        )
+
+    def _create_trg_mask(self, trg: Tensor) -> Tensor:
+        # Move this outside the transformer.
+        trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
+        trg_len = trg.shape[1]
+        trg_sub_mask = torch.tril(
+            torch.ones((trg_len, trg_len), device=trg.device)
+        ).bool()
+        trg_mask = trg_pad_mask & trg_sub_mask
+        return trg_mask
+
+    def encoder(self, src: Tensor) -> Tensor:
+        """Forward pass with the encoder of the transformer."""
+        return self.transformer.encoder(src)
+
+    def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
+        """Forward pass with the decoder of the transformer + classification head."""
+        return self.head(
+            self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+        )
+
+    def _backbone(self, x: Tensor) -> Tensor:
+        b, t = x.shape[:2]
+        if self.use_backbone:
+            x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
+            x = self.backbone(x)
+            if self.mlp_dim:
+                x = rearrange(x, "(b t) c h w -> b t (c h w)", b=b, t=t)
+                x = self.mlp(x)
+            else:
+                x = rearrange(x, "(b t) h -> b t h", b=b, t=t)
+        else:
+            x = rearrange(x, "b t c h w -> b t (c h w)", b=b, t=t)
+            x = self.linear_projection(x)
+        return x
+
+    def preprocess_input(self, src: Tensor) -> Tensor:
+        """Encodes src with a backbone network and a positional encoding.
+
+        Args:
+            src (Tensor): Input tensor.
+
+        Returns:
+            Tensor: An input src for the transformer.
+
+        """
+        # If the batch dimension is missing, it needs to be added.
+        if len(src.shape) < 4:
+            src = src[(None,) * (4 - len(src.shape))]
+        src = self.sliding_window(src)
+        src = self._backbone(src)
+        src = self.position_encoding(src)
+        return src
+
+    def preprocess_target(self, trg: Tensor) -> Tuple[Tensor, Tensor]:
+        """Encodes target tensor with embedding and postion.
+
+        Args:
+            trg (Tensor): Target tensor.
+
+        Returns:
+            Tuple[Tensor, Tensor]: Encoded target tensor and target mask.
+
+        """
+        trg_mask = self._create_trg_mask(trg)
+        trg = self.character_embedding(trg.long())
+        trg = self.position_encoding(trg)
+        return trg, trg_mask
+
+    def forward(self, x: Tensor, trg: Tensor) -> Tensor:
+        """Forward pass with vision transfomer."""
+        src = self.preprocess_input(x)
+        trg, trg_mask = self.preprocess_target(trg)
+        out = self.transformer(src, trg, trg_mask=trg_mask)
+        logits = self.head(out)
+        return logits
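
A minimal sketch with the default linear patch projection (no backbone);
the vocabulary size, pad index, and image width are illustrative:

    import torch
    from text_recognizer.networks import VisionTransformer

    net = VisionTransformer(
        num_encoder_layers=2,
        num_decoder_layers=2,
        hidden_dim=64,
        vocab_size=80,
        num_heads=8,
        max_len=100,
        expansion_dim=256,
        dropout_rate=0.1,
        trg_pad_index=79,
    )
    x = torch.rand(2, 1, 28, 952)        # one text line -> 67 patches
    trg = torch.randint(0, 80, (2, 10))  # target character ids
    logits = net(x, trg)                 # (2, 10, 80)
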
diff --git a/src/text_recognizer/networks/wide_resnet.py b/src/text_recognizer/networks/wide_resnet.py
index 618f414..aa79c12 100644
--- a/src/text_recognizer/networks/wide_resnet.py
+++ b/src/text_recognizer/networks/wide_resnet.py
@@ -8,7 +8,7 @@ import torch
 from torch import nn
 from torch import Tensor
 
-from text_recognizer.networks.misc import activation_function
+from text_recognizer.networks.util import activation_function
 
 
 def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
@@ -206,8 +206,8 @@ class WideResidualNetwork(nn.Module):
 
     def forward(self, x: Tensor) -> Tensor:
         """Feedforward pass."""
-        if len(x.shape) == 3:
-            x = x.unsqueeze(0)
+        if len(x.shape) < 4:
+            x = x[(None,) * (4 - len(x.shape))]
         x = self.encoder(x)
         if self.decoder is not None:
             x = self.decoder(x)
-- 
cgit v1.2.3-70-g09d2