Diffstat (limited to 'src/text_recognizer/networks')
-rw-r--r--  src/text_recognizer/networks/__init__.py                                                      18
-rw-r--r--  src/text_recognizer/networks/cnn_transformer.py                                              111
-rw-r--r--  src/text_recognizer/networks/crnn.py (renamed from src/text_recognizer/networks/line_lstm_ctc.py)  58
-rw-r--r--  src/text_recognizer/networks/densenet.py                                                     225
-rw-r--r--  src/text_recognizer/networks/lenet.py                                                          6
-rw-r--r--  src/text_recognizer/networks/loss.py (renamed from src/text_recognizer/networks/losses.py)      3
-rw-r--r--  src/text_recognizer/networks/mlp.py                                                            6
-rw-r--r--  src/text_recognizer/networks/residual_network.py                                               6
-rw-r--r--  src/text_recognizer/networks/sparse_mlp.py                                                    78
-rw-r--r--  src/text_recognizer/networks/transformer.py                                                    5
-rw-r--r--  src/text_recognizer/networks/transformer/__init__.py                                           3
-rw-r--r--  src/text_recognizer/networks/transformer/attention.py                                         93
-rw-r--r--  src/text_recognizer/networks/transformer/positional_encoding.py                               31
-rw-r--r--  src/text_recognizer/networks/transformer/sparse_transformer.py                                 1
-rw-r--r--  src/text_recognizer/networks/transformer/transformer.py                                       241
-rw-r--r--  src/text_recognizer/networks/util.py (renamed from src/text_recognizer/networks/misc.py)      40
-rw-r--r--  src/text_recognizer/networks/vision_transformer.py                                            158
-rw-r--r--  src/text_recognizer/networks/wide_resnet.py                                                    6
18 files changed, 1025 insertions, 64 deletions
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index a39975f..8b87797 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -1,21 +1,31 @@
"""Network modules."""
+from .cnn_transformer import CNNTransformer
+from .crnn import ConvolutionalRecurrentNetwork
from .ctc import greedy_decoder
+from .densenet import DenseNet
from .lenet import LeNet
-from .line_lstm_ctc import LineRecurrentNetwork
-from .losses import EmbeddingLoss
-from .misc import sliding_window
+from .loss import EmbeddingLoss
from .mlp import MLP
from .residual_network import ResidualNetwork, ResidualNetworkEncoder
+from .sparse_mlp import SparseMLP
+from .transformer import Transformer
+from .util import sliding_window
+from .vision_transformer import VisionTransformer
from .wide_resnet import WideResidualNetwork
__all__ = [
+ "CNNTransformer",
+ "ConvolutionalRecurrentNetwork",
+ "DenseNet",
"EmbeddingLoss",
"greedy_decoder",
"MLP",
"LeNet",
- "LineRecurrentNetwork",
"ResidualNetwork",
"ResidualNetworkEncoder",
"sliding_window",
+ "Transformer",
+ "SparseMLP",
+ "VisionTransformer",
"WideResidualNetwork",
]
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
new file mode 100644
index 0000000..8666f11
--- /dev/null
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -0,0 +1,111 @@
+"""A DETR style transfomers but for text recognition."""
+from typing import Dict, Optional, Tuple
+
+from einops.layers.torch import Rearrange
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer import PositionalEncoding, Transformer
+from text_recognizer.networks.util import configure_backbone
+
+
+class CNNTransformer(nn.Module):
+ """CNN+Transfomer for image to sequence prediction, sort of based on the ideas from DETR."""
+
+ def __init__(
+ self,
+ num_encoder_layers: int,
+ num_decoder_layers: int,
+ hidden_dim: int,
+ vocab_size: int,
+ num_heads: int,
+ max_len: int,
+ expansion_dim: int,
+ dropout_rate: float,
+ trg_pad_index: int,
+ backbone: str,
+ backbone_args: Optional[Dict] = None,
+ activation: str = "gelu",
+ ) -> None:
+ super().__init__()
+ self.trg_pad_index = trg_pad_index
+ self.backbone_args = backbone_args
+        self.backbone = configure_backbone(backbone, backbone_args or {})
+ self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
+ self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate, max_len)
+ self.collapse_spatial_dim = nn.Sequential(
+ Rearrange("b t h w -> b t (h w)"), nn.AdaptiveAvgPool2d((None, hidden_dim))
+ )
+ self.transformer = Transformer(
+ num_encoder_layers,
+ num_decoder_layers,
+ hidden_dim,
+ num_heads,
+ expansion_dim,
+ dropout_rate,
+ activation,
+ )
+ self.head = nn.Linear(hidden_dim, vocab_size)
+
+ def _create_trg_mask(self, trg: Tensor) -> Tensor:
+        # TODO: Move this outside the transformer.
+ trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
+ trg_len = trg.shape[1]
+ trg_sub_mask = torch.tril(
+ torch.ones((trg_len, trg_len), device=trg.device)
+ ).bool()
+ trg_mask = trg_pad_mask & trg_sub_mask
+ return trg_mask
+
+ def encoder(self, src: Tensor) -> Tensor:
+ """Forward pass with the encoder of the transformer."""
+ return self.transformer.encoder(src)
+
+ def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
+ """Forward pass with the decoder of the transformer + classification head."""
+ return self.head(
+ self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+ )
+
+ def preprocess_input(self, src: Tensor) -> Tensor:
+ """Encodes src with a backbone network and a positional encoding.
+
+ Args:
+ src (Tensor): Input tensor.
+
+ Returns:
+            Tensor: An encoded input src for the transformer.
+
+ """
+        # If the batch dimension is missing, it needs to be added.
+ if len(src.shape) < 4:
+ src = src[(None,) * (4 - len(src.shape))]
+ src = self.backbone(src)
+ src = self.collapse_spatial_dim(src)
+ src = self.position_encoding(src)
+ return src
+
+ def preprocess_target(self, trg: Tensor) -> Tuple[Tensor, Tensor]:
+ """Encodes target tensor with embedding and postion.
+
+ Args:
+ trg (Tensor): Target tensor.
+
+ Returns:
+ Tuple[Tensor, Tensor]: Encoded target tensor and target mask.
+
+ """
+ trg_mask = self._create_trg_mask(trg)
+ trg = self.character_embedding(trg.long())
+ trg = self.position_encoding(trg)
+ return trg, trg_mask
+
+ def forward(self, x: Tensor, trg: Tensor) -> Tensor:
+ """Forward pass with CNN transfomer."""
+ src = self.preprocess_input(x)
+ trg, trg_mask = self.preprocess_target(trg)
+ out = self.transformer(src, trg, trg_mask=trg_mask)
+ logits = self.head(out)
+ return logits
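
A minimal usage sketch of the module above; the hyperparameters, backbone choice, and tensor shapes are illustrative assumptions, not values from the repo's configs:

    import torch
    from text_recognizer.networks import CNNTransformer

    # "DenseNet" is assumed to be a valid backbone name exported by
    # text_recognizer.networks. max_len must cover both the backbone's
    # channel count (the source sequence length after the rearrange)
    # and the target length.
    net = CNNTransformer(
        num_encoder_layers=2,
        num_decoder_layers=2,
        hidden_dim=256,
        vocab_size=80,
        num_heads=8,
        max_len=1024,
        expansion_dim=1024,
        dropout_rate=0.1,
        trg_pad_index=79,
        backbone="DenseNet",
        backbone_args={"classifier": False},
    )
    images = torch.randn(4, 1, 28, 952)      # (batch, channels, height, width)
    targets = torch.randint(0, 79, (4, 97))  # (batch, target sequence length)
    logits = net(images, targets)            # (batch, target length, vocab size)
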
diff --git a/src/text_recognizer/networks/line_lstm_ctc.py b/src/text_recognizer/networks/crnn.py
index 9009f94..3e605e2 100644
--- a/src/text_recognizer/networks/line_lstm_ctc.py
+++ b/src/text_recognizer/networks/crnn.py
@@ -10,15 +10,16 @@ import torch
from torch import nn
from torch import Tensor
+from text_recognizer.networks.util import configure_backbone
-class LineRecurrentNetwork(nn.Module):
+
+class ConvolutionalRecurrentNetwork(nn.Module):
"""Network that takes a image of a text line and predicts tokens that are in the image."""
def __init__(
self,
backbone: str,
backbone_args: Dict = None,
- flatten: bool = True,
input_size: int = 128,
hidden_size: int = 128,
bidirectional: bool = False,
@@ -26,6 +27,7 @@ class LineRecurrentNetwork(nn.Module):
num_classes: int = 80,
patch_size: Tuple[int, int] = (28, 28),
stride: Tuple[int, int] = (1, 14),
+ recurrent_cell: str = "lstm",
) -> None:
super().__init__()
self.backbone_args = backbone_args or {}
@@ -34,17 +36,19 @@ class LineRecurrentNetwork(nn.Module):
self.sliding_window = self._configure_sliding_window()
self.input_size = input_size
self.hidden_size = hidden_size
- self.backbone = self._configure_backbone(backbone)
+        self.backbone = configure_backbone(backbone, self.backbone_args)
self.bidirectional = bidirectional
- self.flatten = flatten
- if self.flatten:
- self.fc = nn.Linear(
- in_features=self.input_size, out_features=self.hidden_size
+        if recurrent_cell.upper() in ["LSTM", "GRU"]:
+            recurrent_cell = getattr(nn, recurrent_cell.upper())
+ else:
+ logger.warning(
+ f"Option {recurrent_cell} not valid, defaulting to LSTM cell."
)
+ recurrent_cell = nn.LSTM
- self.rnn = nn.LSTM(
- input_size=self.hidden_size,
+ self.rnn = recurrent_cell(
+ input_size=self.input_size,
hidden_size=self.hidden_size,
bidirectional=bidirectional,
num_layers=num_layers,
@@ -57,32 +61,6 @@ class LineRecurrentNetwork(nn.Module):
nn.LogSoftmax(dim=2),
)
- def _configure_backbone(self, backbone: str) -> Type[nn.Module]:
- network_module = importlib.import_module("text_recognizer.networks")
- backbone_ = getattr(network_module, backbone)
-
- if "pretrained" in self.backbone_args:
- logger.info("Loading pretrained backbone.")
- checkpoint_file = Path(__file__).resolve().parents[
- 2
- ] / self.backbone_args.pop("pretrained")
-
- # Loading state directory.
- state_dict = torch.load(checkpoint_file)
- network_args = state_dict["network_args"]
- weights = state_dict["model_state"]
-
- # Initializes the network with trained weights.
- backbone = backbone_(**network_args)
- backbone.load_state_dict(weights)
- if "freeze" in self.backbone_args and self.backbone_args["freeze"] is True:
- for params in backbone.parameters():
- params.requires_grad = False
-
- return backbone
- else:
- return backbone_(**self.backbone_args)
-
def _configure_sliding_window(self) -> nn.Sequential:
return nn.Sequential(
nn.Unfold(kernel_size=self.patch_size, stride=self.stride),
@@ -96,8 +74,8 @@ class LineRecurrentNetwork(nn.Module):
def forward(self, x: Tensor) -> Tensor:
"""Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM."""
- if len(x.shape) == 3:
- x = x.unsqueeze(0)
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
x = self.sliding_window(x)
# Rearrange from a sequence of patches for feedforward network.
@@ -106,11 +84,7 @@ class LineRecurrentNetwork(nn.Module):
x = self.backbone(x)
    # Average pooling.
- x = (
- self.fc(reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t))
- if self.flatten
- else rearrange(x, "(b t) h -> t b h", b=b, t=t)
- )
+ x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)
# Sequence predictions.
x, _ = self.rnn(x)
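
The recurrent_cell lookup added above resolves a string to the corresponding torch.nn class, falling back to LSTM. A standalone sketch of the same pattern (a hypothetical helper mirroring the diff):

    import torch
    from torch import nn

    def pick_cell(name: str) -> type:
        # "lstm" -> nn.LSTM, "gru" -> nn.GRU; anything else falls back to LSTM.
        if name.upper() in ["LSTM", "GRU"]:
            return getattr(nn, name.upper())
        return nn.LSTM

    rnn = pick_cell("gru")(input_size=128, hidden_size=128, num_layers=2)
    x = torch.randn(10, 4, 128)  # (sequence, batch, features), the (t, b, c) layout used above
    out, _ = rnn(x)              # (10, 4, 128)
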
diff --git a/src/text_recognizer/networks/densenet.py b/src/text_recognizer/networks/densenet.py
new file mode 100644
index 0000000..d2aad60
--- /dev/null
+++ b/src/text_recognizer/networks/densenet.py
@@ -0,0 +1,225 @@
+"""Defines a Densely Connected Convolutional Networks in PyTorch.
+
+Sources:
+https://arxiv.org/abs/1608.06993
+https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py
+
+"""
+from typing import List, Tuple, Union
+
+from einops.layers.torch import Rearrange
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import activation_function
+
+
+class _DenseLayer(nn.Module):
+ """A dense layer with pre-batch norm -> activation function -> Conv-layer x 2."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ growth_rate: int,
+ bn_size: int,
+ dropout_rate: float,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ activation_fn = activation_function(activation)
+ self.dense_layer = [
+ nn.BatchNorm2d(in_channels),
+ activation_fn,
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=bn_size * growth_rate,
+ kernel_size=1,
+ stride=1,
+ bias=False,
+ ),
+ nn.BatchNorm2d(bn_size * growth_rate),
+ activation_fn,
+ nn.Conv2d(
+ in_channels=bn_size * growth_rate,
+ out_channels=growth_rate,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ ),
+ ]
+ if dropout_rate:
+ self.dense_layer.append(nn.Dropout(p=dropout_rate))
+
+ self.dense_layer = nn.Sequential(*self.dense_layer)
+
+ def forward(self, x: Union[Tensor, List[Tensor]]) -> Tensor:
+ if isinstance(x, list):
+ x = torch.cat(x, 1)
+ return self.dense_layer(x)
+
+
+class _DenseBlock(nn.Module):
+ def __init__(
+ self,
+ num_layers: int,
+ in_channels: int,
+ bn_size: int,
+ growth_rate: int,
+ dropout_rate: float,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ self.dense_block = self._build_dense_blocks(
+ num_layers, in_channels, bn_size, growth_rate, dropout_rate, activation
+ )
+
+ def _build_dense_blocks(
+ self,
+ num_layers: int,
+ in_channels: int,
+ bn_size: int,
+ growth_rate: int,
+ dropout_rate: float,
+ activation: str = "relu",
+ ) -> nn.ModuleList:
+ dense_block = []
+ for i in range(num_layers):
+ dense_block.append(
+ _DenseLayer(
+ in_channels=in_channels + i * growth_rate,
+ growth_rate=growth_rate,
+ bn_size=bn_size,
+ dropout_rate=dropout_rate,
+ activation=activation,
+ )
+ )
+ return nn.ModuleList(dense_block)
+
+ def forward(self, x: Tensor) -> Tensor:
+ feature_maps = [x]
+ for layer in self.dense_block:
+ x = layer(feature_maps)
+ feature_maps.append(x)
+ return torch.cat(feature_maps, 1)
+
+
+class _Transition(nn.Module):
+ def __init__(
+ self, in_channels: int, out_channels: int, activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ activation_fn = activation_function(activation)
+ self.transition = nn.Sequential(
+ nn.BatchNorm2d(in_channels),
+ activation_fn,
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ bias=False,
+ ),
+ nn.AvgPool2d(kernel_size=2, stride=2),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.transition(x)
+
+
+class DenseNet(nn.Module):
+ """Implementation of Densenet, a network archtecture that concats previous layers for maximum infomation flow."""
+
+ def __init__(
+ self,
+ growth_rate: int = 32,
+        block_config: Tuple[int, ...] = (6, 12, 24, 16),
+ in_channels: int = 1,
+ base_channels: int = 64,
+ num_classes: int = 80,
+ bn_size: int = 4,
+ dropout_rate: float = 0,
+ classifier: bool = True,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ self.densenet = self._configure_densenet(
+ in_channels,
+ base_channels,
+ num_classes,
+ growth_rate,
+ block_config,
+ bn_size,
+ dropout_rate,
+ classifier,
+ activation,
+ )
+
+ def _configure_densenet(
+ self,
+ in_channels: int,
+ base_channels: int,
+ num_classes: int,
+ growth_rate: int,
+        block_config: Tuple[int, ...],
+ bn_size: int,
+ dropout_rate: float,
+ classifier: bool,
+ activation: str,
+ ) -> nn.Sequential:
+ activation_fn = activation_function(activation)
+ densenet = [
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=base_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ ),
+ nn.BatchNorm2d(base_channels),
+ activation_fn,
+ ]
+
+ num_features = base_channels
+
+ for i, num_layers in enumerate(block_config):
+ densenet.append(
+ _DenseBlock(
+ num_layers=num_layers,
+ in_channels=num_features,
+ bn_size=bn_size,
+ growth_rate=growth_rate,
+ dropout_rate=dropout_rate,
+ activation=activation,
+ )
+ )
+ num_features = num_features + num_layers * growth_rate
+ if i != len(block_config) - 1:
+ densenet.append(
+ _Transition(
+ in_channels=num_features,
+ out_channels=num_features // 2,
+ activation=activation,
+ )
+ )
+ num_features = num_features // 2
+
+ densenet.append(activation_fn)
+
+ if classifier:
+ densenet.append(nn.AdaptiveAvgPool2d((1, 1)))
+ densenet.append(Rearrange("b c h w -> b (c h w)"))
+ densenet.append(
+ nn.Linear(in_features=num_features, out_features=num_classes)
+ )
+
+ return nn.Sequential(*densenet)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass of Densenet."""
+ # If batch dimenstion is missing, it needs to be added.
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
+ return self.densenet(x)
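
The channel bookkeeping in _configure_densenet follows the DenseNet paper: each layer in a block adds growth_rate channels, and each transition halves them. A worked check for the defaults:

    # base_channels=64, growth_rate=32, block_config=(6, 12, 24, 16)
    num_features = 64
    for i, num_layers in enumerate((6, 12, 24, 16)):
        num_features += num_layers * 32  # each dense layer adds growth_rate channels
        if i != 3:                       # every block but the last is followed by a transition
            num_features //= 2
    print(num_features)  # 1024 features reach the classifier head
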
diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py
index 53c575e..527e1a0 100644
--- a/src/text_recognizer/networks/lenet.py
+++ b/src/text_recognizer/networks/lenet.py
@@ -5,7 +5,7 @@ from einops.layers.torch import Rearrange
import torch
from torch import nn
-from text_recognizer.networks.misc import activation_function
+from text_recognizer.networks.util import activation_function
class LeNet(nn.Module):
@@ -63,6 +63,6 @@ class LeNet(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""The feedforward pass."""
        # If the batch dimension is missing, it needs to be added.
- if len(x.shape) == 3:
- x = x.unsqueeze(0)
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
return self.layers(x)
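
The new guard generalizes the old unsqueeze(0): indexing with a tuple of None values prepends one dimension per None, so a 3-dimensional input gains exactly the missing batch dimension. A quick standalone check:

    import torch

    x = torch.randn(1, 28, 28)           # (channels, height, width), batch dimension missing
    y = x[(None,) * (4 - len(x.shape))]  # one None -> one new leading dimension
    assert y.shape == (1, 1, 28, 28)
    assert torch.equal(y, x.unsqueeze(0))
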
diff --git a/src/text_recognizer/networks/losses.py b/src/text_recognizer/networks/loss.py
index 73e0641..ff843cf 100644
--- a/src/text_recognizer/networks/losses.py
+++ b/src/text_recognizer/networks/loss.py
@@ -4,6 +4,9 @@ from torch import nn
from torch import Tensor
+__all__ = ["EmbeddingLoss"]
+
+
class EmbeddingLoss:
"""Metric loss for training encoders to produce information-rich latent embeddings."""
diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py
index d66af28..1101912 100644
--- a/src/text_recognizer/networks/mlp.py
+++ b/src/text_recognizer/networks/mlp.py
@@ -5,7 +5,7 @@ from einops.layers.torch import Rearrange
import torch
from torch import nn
-from text_recognizer.networks.misc import activation_function
+from text_recognizer.networks.util import activation_function
class MLP(nn.Module):
@@ -63,8 +63,8 @@ class MLP(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""The feedforward pass."""
        # If the batch dimension is missing, it needs to be added.
- if len(x.shape) == 3:
- x = x.unsqueeze(0)
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
return self.layers(x)
@property
diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py
index 046600d..6405192 100644
--- a/src/text_recognizer/networks/residual_network.py
+++ b/src/text_recognizer/networks/residual_network.py
@@ -7,8 +7,8 @@ import torch
from torch import nn
from torch import Tensor
-from text_recognizer.networks.misc import activation_function
from text_recognizer.networks.stn import SpatialTransformerNetwork
+from text_recognizer.networks.util import activation_function
class Conv2dAuto(nn.Conv2d):
@@ -225,8 +225,8 @@ class ResidualNetworkEncoder(nn.Module):
in_channels=in_channels,
out_channels=self.block_sizes[0],
kernel_size=3,
- stride=2,
- padding=3,
+ stride=1,
+ padding=1,
bias=False,
),
nn.BatchNorm2d(self.block_sizes[0]),
diff --git a/src/text_recognizer/networks/sparse_mlp.py b/src/text_recognizer/networks/sparse_mlp.py
new file mode 100644
index 0000000..53cf166
--- /dev/null
+++ b/src/text_recognizer/networks/sparse_mlp.py
@@ -0,0 +1,78 @@
+"""Defines the Sparse MLP network."""
+from typing import List, Union
+import warnings
+
+from einops.layers.torch import Rearrange
+from pytorch_block_sparse import BlockSparseLinear
+import torch
+from torch import nn
+
+from text_recognizer.networks.util import activation_function
+
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+
+class SparseMLP(nn.Module):
+ """Sparse multi layered perceptron network."""
+
+ def __init__(
+ self,
+ input_size: int = 784,
+ num_classes: int = 10,
+ hidden_size: Union[int, List] = 128,
+ num_layers: int = 3,
+ density: float = 0.1,
+ activation_fn: str = "relu",
+ ) -> None:
+ """Initialization of the MLP network.
+
+ Args:
+ input_size (int): The input shape of the network. Defaults to 784.
+ num_classes (int): Number of classes in the dataset. Defaults to 10.
+ hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128.
+ num_layers (int): The number of hidden layers. Defaults to 3.
+            density (float): The density of the block-sparse linear layers. Defaults to 0.1.
+ activation_fn (str): Name of the activation function in the hidden layers. Defaults to
+ relu.
+
+ """
+ super().__init__()
+
+ activation_fn = activation_function(activation_fn)
+
+ if isinstance(hidden_size, int):
+ hidden_size = [hidden_size] * num_layers
+
+ self.layers = [
+ Rearrange("b c h w -> b (c h w)"),
+ nn.Linear(in_features=input_size, out_features=hidden_size[0]),
+ activation_fn,
+ ]
+
+ for i in range(num_layers - 1):
+ self.layers += [
+ BlockSparseLinear(
+ in_features=hidden_size[i],
+ out_features=hidden_size[i + 1],
+ density=density,
+ ),
+ activation_fn,
+ ]
+
+ self.layers.append(
+ nn.Linear(in_features=hidden_size[-1], out_features=num_classes)
+ )
+
+ self.layers = nn.Sequential(*self.layers)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """The feedforward pass."""
+        # If the batch dimension is missing, it needs to be added.
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
+ return self.layers(x)
+
+ @property
+ def __name__(self) -> str:
+ """Returns the name of the network."""
+ return "mlp"
diff --git a/src/text_recognizer/networks/transformer.py b/src/text_recognizer/networks/transformer.py
deleted file mode 100644
index c091ba0..0000000
--- a/src/text_recognizer/networks/transformer.py
+++ /dev/null
@@ -1,5 +0,0 @@
-"""TBC."""
-from typing import Dict
-
-import torch
-from torch import Tensor
diff --git a/src/text_recognizer/networks/transformer/__init__.py b/src/text_recognizer/networks/transformer/__init__.py
new file mode 100644
index 0000000..020a917
--- /dev/null
+++ b/src/text_recognizer/networks/transformer/__init__.py
@@ -0,0 +1,3 @@
+"""Transformer modules."""
+from .positional_encoding import PositionalEncoding
+from .transformer import Decoder, Encoder, Transformer
diff --git a/src/text_recognizer/networks/transformer/attention.py b/src/text_recognizer/networks/transformer/attention.py
new file mode 100644
index 0000000..cce1ecc
--- /dev/null
+++ b/src/text_recognizer/networks/transformer/attention.py
@@ -0,0 +1,93 @@
+"""Implementes the attention module for the transformer."""
+from typing import Optional, Tuple
+
+from einops import rearrange
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class MultiHeadAttention(nn.Module):
+ """Implementation of multihead attention."""
+
+ def __init__(
+ self, hidden_dim: int, num_heads: int = 8, dropout_rate: float = 0.0
+ ) -> None:
+ super().__init__()
+ self.hidden_dim = hidden_dim
+ self.num_heads = num_heads
+ self.fc_q = nn.Linear(
+ in_features=hidden_dim, out_features=hidden_dim, bias=False
+ )
+ self.fc_k = nn.Linear(
+ in_features=hidden_dim, out_features=hidden_dim, bias=False
+ )
+ self.fc_v = nn.Linear(
+ in_features=hidden_dim, out_features=hidden_dim, bias=False
+ )
+ self.fc_out = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
+
+ self._init_weights()
+
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ def _init_weights(self) -> None:
+ nn.init.normal_(
+ self.fc_q.weight,
+ mean=0,
+ std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
+ )
+ nn.init.normal_(
+ self.fc_k.weight,
+ mean=0,
+ std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
+ )
+ nn.init.normal_(
+ self.fc_v.weight,
+ mean=0,
+ std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
+ )
+ nn.init.xavier_normal_(self.fc_out.weight)
+
+ def scaled_dot_product_attention(
+ self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
+    ) -> Tuple[Tensor, Tensor]:
+ """Calculates the scaled dot product attention."""
+
+ # Compute the energy.
+ energy = torch.einsum("bhlk,bhtk->bhlt", [query, key]) / np.sqrt(
+ query.shape[-1]
+ )
+
+ # If we have a mask for padding some inputs.
+ if mask is not None:
+ energy = energy.masked_fill(mask == 0, -np.inf)
+
+ # Compute the attention from the energy.
+ attention = torch.softmax(energy, dim=3)
+
+ out = torch.einsum("bhlt,bhtv->bhlv", [attention, value])
+ out = rearrange(out, "b head l v -> b l (head v)")
+ return out, attention
+
+ def forward(
+ self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
+ ) -> Tuple[Tensor, Tensor]:
+ """Forward pass for computing the multihead attention."""
+ # Get the query, key, and value tensor.
+ query = rearrange(
+ self.fc_q(query), "b l (head k) -> b head l k", head=self.num_heads
+ )
+ key = rearrange(
+ self.fc_k(key), "b t (head k) -> b head t k", head=self.num_heads
+ )
+ value = rearrange(
+ self.fc_v(value), "b t (head v) -> b head t v", head=self.num_heads
+ )
+
+ out, attention = self.scaled_dot_product_attention(query, key, value, mask)
+
+ out = self.fc_out(out)
+ out = self.dropout(out)
+ return out, attention
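
The two einsums above compute softmax(QK^T / sqrt(d))V per head. A shape walkthrough with plain tensors (sizes are illustrative):

    import torch

    b, heads, l, t, k = 2, 8, 5, 7, 32  # batch, heads, query length, key length, head dim
    query = torch.randn(b, heads, l, k)
    key = torch.randn(b, heads, t, k)
    value = torch.randn(b, heads, t, k)

    energy = torch.einsum("bhlk,bhtk->bhlt", [query, key]) / k ** 0.5  # (2, 8, 5, 7)
    attention = torch.softmax(energy, dim=3)  # each row sums to 1 over the key positions
    out = torch.einsum("bhlt,bhtv->bhlv", [attention, value])  # (2, 8, 5, 32)
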
diff --git a/src/text_recognizer/networks/transformer/positional_encoding.py b/src/text_recognizer/networks/transformer/positional_encoding.py
new file mode 100644
index 0000000..a47141b
--- /dev/null
+++ b/src/text_recognizer/networks/transformer/positional_encoding.py
@@ -0,0 +1,31 @@
+"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence."""
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class PositionalEncoding(nn.Module):
+ """Encodes a sense of distance or time for transformer networks."""
+
+ def __init__(
+ self, hidden_dim: int, dropout_rate: float, max_len: int = 1000
+ ) -> None:
+ super().__init__()
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ pe = torch.zeros(max_len, hidden_dim)
+ position = torch.arange(0, max_len).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim)
+ )
+
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0)
+ self.register_buffer("pe", pe)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Encodes the tensor with a postional embedding."""
+ x = x + self.pe[:, : x.shape[1]]
+ return self.dropout(x)
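
The registered buffer implements the sinusoidal encoding PE(pos, 2i) = sin(pos / 10000^(2i/d)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d)). A small standalone check of the table construction:

    import numpy as np
    import torch

    hidden_dim, max_len = 8, 4
    position = torch.arange(0, max_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim))
    pe = torch.zeros(max_len, hidden_dim)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    # At position 2, dimension 0: div_term[0] == 1, so the entry is sin(2).
    assert torch.allclose(pe[2, 0], torch.sin(torch.tensor(2.0)))
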
diff --git a/src/text_recognizer/networks/transformer/sparse_transformer.py b/src/text_recognizer/networks/transformer/sparse_transformer.py
new file mode 100644
index 0000000..8c391c8
--- /dev/null
+++ b/src/text_recognizer/networks/transformer/sparse_transformer.py
@@ -0,0 +1 @@
+"""Encoder and Decoder modules using spares activations."""
diff --git a/src/text_recognizer/networks/transformer/transformer.py b/src/text_recognizer/networks/transformer/transformer.py
new file mode 100644
index 0000000..1c9c7dd
--- /dev/null
+++ b/src/text_recognizer/networks/transformer/transformer.py
@@ -0,0 +1,241 @@
+"""Transfomer module."""
+import copy
+from typing import Optional, Type
+
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer.attention import MultiHeadAttention
+from text_recognizer.networks.util import activation_function
+
+
+def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList:
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)])
+
+
+class _IntraLayerConnection(nn.Module):
+ """Preforms the residual connection inside the transfomer blocks and applies layernorm."""
+
+ def __init__(self, dropout_rate: float, hidden_dim: int) -> None:
+ super().__init__()
+ self.norm = nn.LayerNorm(normalized_shape=hidden_dim)
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ def forward(self, src: Tensor, residual: Tensor) -> Tensor:
+ return self.norm(self.dropout(src) + residual)
+
+
+class _ConvolutionalLayer(nn.Module):
+ def __init__(
+ self,
+ hidden_dim: int,
+ expansion_dim: int,
+ dropout_rate: float,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ self.layer = nn.Sequential(
+ nn.Linear(in_features=hidden_dim, out_features=expansion_dim),
+ activation_function(activation),
+ nn.Dropout(p=dropout_rate),
+ nn.Linear(in_features=expansion_dim, out_features=hidden_dim),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.layer(x)
+
+
+class EncoderLayer(nn.Module):
+ """Transfomer encoding layer."""
+
+ def __init__(
+ self,
+ hidden_dim: int,
+ num_heads: int,
+ expansion_dim: int,
+ dropout_rate: float,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
+ self.cnn = _ConvolutionalLayer(
+ hidden_dim, expansion_dim, dropout_rate, activation
+ )
+ self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
+ self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
+
+ def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor:
+ """Forward pass through the encoder."""
+ # First block.
+ # Multi head attention.
+ out, _ = self.self_attention(src, src, src, mask)
+
+ # Add & norm.
+ out = self.block1(out, src)
+
+ # Second block.
+ # Apply 1D-convolution.
+ cnn_out = self.cnn(out)
+
+ # Add & norm.
+ out = self.block2(cnn_out, out)
+
+ return out
+
+
+class Encoder(nn.Module):
+ """Transfomer encoder module."""
+
+ def __init__(
+ self,
+ num_layers: int,
+ encoder_layer: Type[nn.Module],
+ norm: Optional[Type[nn.Module]] = None,
+ ) -> None:
+ super().__init__()
+ self.layers = _get_clones(encoder_layer, num_layers)
+ self.norm = norm
+
+ def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
+ """Forward pass through all encoder layers."""
+ for layer in self.layers:
+ src = layer(src, src_mask)
+
+ if self.norm is not None:
+ src = self.norm(src)
+
+ return src
+
+
+class DecoderLayer(nn.Module):
+ """Transfomer decoder layer."""
+
+ def __init__(
+ self,
+ hidden_dim: int,
+ num_heads: int,
+ expansion_dim: int,
+ dropout_rate: float = 0.0,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ self.hidden_dim = hidden_dim
+ self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
+ self.multihead_attention = MultiHeadAttention(
+ hidden_dim, num_heads, dropout_rate
+ )
+ self.cnn = _ConvolutionalLayer(
+ hidden_dim, expansion_dim, dropout_rate, activation
+ )
+ self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
+ self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
+ self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim)
+
+ def forward(
+ self,
+ trg: Tensor,
+ memory: Tensor,
+ trg_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """Forward pass of the layer."""
+ out, _ = self.self_attention(trg, trg, trg, trg_mask)
+ trg = self.block1(out, trg)
+
+ out, _ = self.multihead_attention(trg, memory, memory, memory_mask)
+ trg = self.block2(out, trg)
+
+ out = self.cnn(trg)
+ out = self.block3(out, trg)
+
+ return out
+
+
+class Decoder(nn.Module):
+ """Transfomer decoder module."""
+
+ def __init__(
+ self,
+ decoder_layer: Type[nn.Module],
+ num_layers: int,
+ norm: Optional[Type[nn.Module]] = None,
+ ) -> None:
+ super().__init__()
+ self.layers = _get_clones(decoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+
+ def forward(
+ self,
+ trg: Tensor,
+ memory: Tensor,
+ trg_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """Forward pass through the decoder."""
+ for layer in self.layers:
+ trg = layer(trg, memory, trg_mask, memory_mask)
+
+ if self.norm is not None:
+ trg = self.norm(trg)
+
+ return trg
+
+
+class Transformer(nn.Module):
+ """Transformer network."""
+
+ def __init__(
+ self,
+ num_encoder_layers: int,
+ num_decoder_layers: int,
+ hidden_dim: int,
+ num_heads: int,
+ expansion_dim: int,
+ dropout_rate: float,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+
+ # Configure encoder.
+ encoder_norm = nn.LayerNorm(hidden_dim)
+ encoder_layer = EncoderLayer(
+ hidden_dim, num_heads, expansion_dim, dropout_rate, activation
+ )
+ self.encoder = Encoder(num_encoder_layers, encoder_layer, encoder_norm)
+
+ # Configure decoder.
+ decoder_norm = nn.LayerNorm(hidden_dim)
+ decoder_layer = DecoderLayer(
+ hidden_dim, num_heads, expansion_dim, dropout_rate, activation
+ )
+ self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm)
+
+ self._reset_parameters()
+
+ def _reset_parameters(self) -> None:
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward(
+ self,
+ src: Tensor,
+ trg: Tensor,
+ src_mask: Optional[Tensor] = None,
+ trg_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """Forward pass through the transformer."""
+ if src.shape[0] != trg.shape[0]:
+ raise RuntimeError("The batch size of the src and trg must be the same.")
+ if src.shape[2] != trg.shape[2]:
+ raise RuntimeError(
+ "The number of features for the src and trg must be the same."
+ )
+
+ memory = self.encoder(src, src_mask)
+ output = self.decoder(trg, memory, trg_mask, memory_mask)
+ return output
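
An end-to-end shape sketch for the Transformer above (sizes illustrative). As the forward pass asserts, src and trg must agree in batch size and in the feature dimension hidden_dim:

    import torch
    from text_recognizer.networks.transformer import Transformer

    net = Transformer(
        num_encoder_layers=2,
        num_decoder_layers=2,
        hidden_dim=64,
        num_heads=8,
        expansion_dim=256,
        dropout_rate=0.1,
    )
    src = torch.randn(4, 10, 64)  # (batch, source length, hidden_dim)
    trg = torch.randn(4, 7, 64)   # (batch, target length, hidden_dim)
    out = net(src, trg)           # (4, 7, 64)
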
diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/util.py
index 1f853e9..0d08506 100644
--- a/src/text_recognizer/networks/misc.py
+++ b/src/text_recognizer/networks/util.py
@@ -1,7 +1,10 @@
"""Miscellaneous neural network functionality."""
-from typing import Tuple, Type
+import importlib
+from pathlib import Path
+from typing import Dict, Tuple, Type
from einops import rearrange
+from loguru import logger
import torch
from torch import nn
@@ -43,3 +46,38 @@ def activation_function(activation: str) -> Type[nn.Module]:
]
)
return activation_fns[activation.lower()]
+
+
+def configure_backbone(backbone: str, backbone_args: Dict) -> Type[nn.Module]:
+ """Loads a backbone network."""
+ network_module = importlib.import_module("text_recognizer.networks")
+ backbone_ = getattr(network_module, backbone)
+
+ if "pretrained" in backbone_args:
+ logger.info("Loading pretrained backbone.")
+ checkpoint_file = Path(__file__).resolve().parents[2] / backbone_args.pop(
+ "pretrained"
+ )
+
+        # Load the state dict from the checkpoint.
+ state_dict = torch.load(checkpoint_file)
+ network_args = state_dict["network_args"]
+ weights = state_dict["model_state"]
+
+ # Initializes the network with trained weights.
+ backbone = backbone_(**network_args)
+ backbone.load_state_dict(weights)
+ if "freeze" in backbone_args and backbone_args["freeze"] is True:
+ for params in backbone.parameters():
+ params.requires_grad = False
+
+    else:
+        backbone = backbone_(**backbone_args)
+
+ if "remove_layers" in backbone_args and backbone_args["remove_layers"] is not None:
+ backbone = nn.Sequential(
+ *list(backbone.children())[0][: -backbone_args["remove_layers"]]
+ )
+
+ return backbone
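
A hypothetical call to configure_backbone; the checkpoint path and argument values are illustrative, not taken from the repo. The "pretrained" path is resolved relative to the project source root, and the checkpoint is expected to carry "network_args" and "model_state" entries:

    from text_recognizer.networks.util import configure_backbone

    backbone = configure_backbone(
        "DenseNet",
        {
            "pretrained": "weights/DenseNet_best.pt",  # illustrative path
            "freeze": True,       # exclude backbone weights from gradient updates
            "remove_layers": 1,   # strip the classifier head off the inner Sequential
        },
    )
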
diff --git a/src/text_recognizer/networks/vision_transformer.py b/src/text_recognizer/networks/vision_transformer.py
new file mode 100644
index 0000000..4d204d3
--- /dev/null
+++ b/src/text_recognizer/networks/vision_transformer.py
@@ -0,0 +1,158 @@
+"""VisionTransformer module.
+
+Splits each image into patches and feeds them to a transformer.
+
+"""
+
+from typing import Dict, Optional, Tuple
+
+from einops import rearrange
+from einops.layers.torch import Rearrange
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer import PositionalEncoding, Transformer
+from text_recognizer.networks.util import configure_backbone
+
+
+class VisionTransformer(nn.Module):
+ """Linear projection+Transfomer for image to sequence prediction, sort of based on the ideas from ViT."""
+
+ def __init__(
+ self,
+ num_encoder_layers: int,
+ num_decoder_layers: int,
+ hidden_dim: int,
+ vocab_size: int,
+ num_heads: int,
+ max_len: int,
+ expansion_dim: int,
+ mlp_dim: int,
+ dropout_rate: float,
+ trg_pad_index: int,
+ patch_size: Tuple[int, int] = (28, 28),
+ stride: Tuple[int, int] = (1, 14),
+ activation: str = "gelu",
+ backbone: Optional[str] = None,
+ backbone_args: Optional[Dict] = None,
+ ) -> None:
+ super().__init__()
+
+ self.patch_size = patch_size
+ self.stride = stride
+ self.trg_pad_index = trg_pad_index
+        self.sliding_window = self._configure_sliding_window()
+ self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
+ self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate, max_len)
+
+ self.use_backbone = False
+ if backbone is None:
+ self.linear_projection = nn.Linear(
+ self.patch_size[0] * self.patch_size[1], hidden_dim
+ )
+ else:
+            self.backbone = configure_backbone(backbone, backbone_args or {})
+ self.use_backbone = True
+
+ self.transformer = Transformer(
+ num_encoder_layers,
+ num_decoder_layers,
+ hidden_dim,
+ num_heads,
+ expansion_dim,
+ dropout_rate,
+ activation,
+ )
+
+ self.head = nn.Sequential(
+ nn.LayerNorm(hidden_dim),
+ nn.Linear(hidden_dim, mlp_dim),
+ nn.GELU(),
+ nn.Dropout(p=dropout_rate),
+ nn.Linear(mlp_dim, vocab_size),
+ )
+
+ def _configure_sliding_window(self) -> nn.Sequential:
+ return nn.Sequential(
+ nn.Unfold(kernel_size=self.patch_size, stride=self.stride),
+ Rearrange(
+ "b (c h w) t -> b t c h w",
+ h=self.patch_size[0],
+ w=self.patch_size[1],
+ c=1,
+ ),
+ )
+
+ def _create_trg_mask(self, trg: Tensor) -> Tensor:
+        # TODO: Move this outside the transformer.
+ trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
+ trg_len = trg.shape[1]
+ trg_sub_mask = torch.tril(
+ torch.ones((trg_len, trg_len), device=trg.device)
+ ).bool()
+ trg_mask = trg_pad_mask & trg_sub_mask
+ return trg_mask
+
+ def encoder(self, src: Tensor) -> Tensor:
+ """Forward pass with the encoder of the transformer."""
+ return self.transformer.encoder(src)
+
+ def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
+ """Forward pass with the decoder of the transformer + classification head."""
+ return self.head(
+ self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+ )
+
+ def _backbone(self, x: Tensor) -> Tensor:
+ b, t = x.shape[:2]
+ if self.use_backbone:
+ x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
+ x = self.backbone(x)
+ x = rearrange(x, "(b t) h -> b t h", b=b, t=t)
+ else:
+ x = rearrange(x, "b t c h w -> b t (c h w)", b=b, t=t)
+ x = self.linear_projection(x)
+ return x
+
+ def preprocess_input(self, src: Tensor) -> Tensor:
+ """Encodes src with a backbone network and a positional encoding.
+
+ Args:
+ src (Tensor): Input tensor.
+
+ Returns:
+            Tensor: An encoded input src for the transformer.
+
+ """
+        # If the batch dimension is missing, it needs to be added.
+ if len(src.shape) < 4:
+ src = src[(None,) * (4 - len(src.shape))]
+        src = self.sliding_window(src)
+ src = self._backbone(src)
+ src = self.position_encoding(src)
+ return src
+
+ def preprocess_target(self, trg: Tensor) -> Tuple[Tensor, Tensor]:
+ """Encodes target tensor with embedding and postion.
+
+ Args:
+ trg (Tensor): Target tensor.
+
+ Returns:
+ Tuple[Tensor, Tensor]: Encoded target tensor and target mask.
+
+ """
+ trg_mask = self._create_trg_mask(trg)
+ trg = self.character_embedding(trg.long())
+ trg = self.position_encoding(trg)
+ return trg, trg_mask
+
+ def forward(self, x: Tensor, trg: Tensor) -> Tensor:
+ """Forward pass with vision transfomer."""
+ src = self.preprocess_input(x)
+ trg, trg_mask = self.preprocess_target(trg)
+ out = self.transformer(src, trg, trg_mask=trg_mask)
+ logits = self.head(out)
+ return logits
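
_create_trg_mask combines a padding mask with a causal (lower-triangular) mask, so a decoder position can attend only to earlier, non-padded positions. A standalone illustration:

    import torch

    trg_pad_index = 0
    trg = torch.tensor([[5, 3, 0]])  # one sequence whose last token is padding
    trg_pad_mask = (trg != trg_pad_index)[:, None, None]  # (1, 1, 1, 3)
    trg_sub_mask = torch.tril(torch.ones((3, 3))).bool()  # causal mask
    trg_mask = trg_pad_mask & trg_sub_mask                # (1, 1, 3, 3)
    # Row i allows attention to positions <= i, never to the padded slot:
    # [[True, False, False],
    #  [True, True,  False],
    #  [True, True,  False]]
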
diff --git a/src/text_recognizer/networks/wide_resnet.py b/src/text_recognizer/networks/wide_resnet.py
index 618f414..aa79c12 100644
--- a/src/text_recognizer/networks/wide_resnet.py
+++ b/src/text_recognizer/networks/wide_resnet.py
@@ -8,7 +8,7 @@ import torch
from torch import nn
from torch import Tensor
-from text_recognizer.networks.misc import activation_function
+from text_recognizer.networks.util import activation_function
def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
@@ -206,8 +206,8 @@ class WideResidualNetwork(nn.Module):
def forward(self, x: Tensor) -> Tensor:
"""Feedforward pass."""
- if len(x.shape) == 3:
- x = x.unsqueeze(0)
+ if len(x.shape) < 4:
+            x = x[(None,) * (4 - len(x.shape))]
x = self.encoder(x)
if self.decoder is not None:
x = self.decoder(x)