diff options
Diffstat (limited to 'src/text_recognizer/networks')
18 files changed, 1025 insertions, 64 deletions
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py index a39975f..8b87797 100644 --- a/src/text_recognizer/networks/__init__.py +++ b/src/text_recognizer/networks/__init__.py @@ -1,21 +1,31 @@ """Network modules.""" +from .cnn_transformer import CNNTransformer +from .crnn import ConvolutionalRecurrentNetwork from .ctc import greedy_decoder +from .densenet import DenseNet from .lenet import LeNet -from .line_lstm_ctc import LineRecurrentNetwork -from .losses import EmbeddingLoss -from .misc import sliding_window +from .loss import EmbeddingLoss from .mlp import MLP from .residual_network import ResidualNetwork, ResidualNetworkEncoder +from .sparse_mlp import SparseMLP +from .transformer import Transformer +from .util import sliding_window +from .vision_transformer import VisionTransformer from .wide_resnet import WideResidualNetwork __all__ = [ + "CNNTransformer", + "ConvolutionalRecurrentNetwork", + "DenseNet", "EmbeddingLoss", "greedy_decoder", "MLP", "LeNet", - "LineRecurrentNetwork", "ResidualNetwork", "ResidualNetworkEncoder", "sliding_window", + "Transformer", + "SparseMLP", + "VisionTransformer", "WideResidualNetwork", ] diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py new file mode 100644 index 0000000..8666f11 --- /dev/null +++ b/src/text_recognizer/networks/cnn_transformer.py @@ -0,0 +1,111 @@ +"""A DETR style transfomers but for text recognition.""" +from typing import Dict, Optional, Tuple, Type + +from einops.layers.torch import Rearrange +from loguru import logger +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.transformer import PositionalEncoding, Transformer +from text_recognizer.networks.util import configure_backbone + + +class CNNTransformer(nn.Module): + """CNN+Transfomer for image to sequence prediction, sort of based on the ideas from DETR.""" + + def __init__( + self, + num_encoder_layers: int, + num_decoder_layers: int, + hidden_dim: int, + vocab_size: int, + num_heads: int, + max_len: int, + expansion_dim: int, + dropout_rate: float, + trg_pad_index: int, + backbone: str, + backbone_args: Optional[Dict] = None, + activation: str = "gelu", + ) -> None: + super().__init__() + self.trg_pad_index = trg_pad_index + self.backbone_args = backbone_args + self.backbone = configure_backbone(backbone, backbone_args) + self.character_embedding = nn.Embedding(vocab_size, hidden_dim) + self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate, max_len) + self.collapse_spatial_dim = nn.Sequential( + Rearrange("b t h w -> b t (h w)"), nn.AdaptiveAvgPool2d((None, hidden_dim)) + ) + self.transformer = Transformer( + num_encoder_layers, + num_decoder_layers, + hidden_dim, + num_heads, + expansion_dim, + dropout_rate, + activation, + ) + self.head = nn.Linear(hidden_dim, vocab_size) + + def _create_trg_mask(self, trg: Tensor) -> Tensor: + # Move this outside the transformer. + trg_pad_mask = (trg != self.trg_pad_index)[:, None, None] + trg_len = trg.shape[1] + trg_sub_mask = torch.tril( + torch.ones((trg_len, trg_len), device=trg.device) + ).bool() + trg_mask = trg_pad_mask & trg_sub_mask + return trg_mask + + def encoder(self, src: Tensor) -> Tensor: + """Forward pass with the encoder of the transformer.""" + return self.transformer.encoder(src) + + def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor: + """Forward pass with the decoder of the transformer + classification head.""" + return self.head( + self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask) + ) + + def preprocess_input(self, src: Tensor) -> Tensor: + """Encodes src with a backbone network and a positional encoding. + + Args: + src (Tensor): Input tensor. + + Returns: + Tensor: A input src to the transformer. + + """ + # If batch dimenstion is missing, it needs to be added. + if len(src.shape) < 4: + src = src[(None,) * (4 - len(src.shape))] + src = self.backbone(src) + src = self.collapse_spatial_dim(src) + src = self.position_encoding(src) + return src + + def preprocess_target(self, trg: Tensor) -> Tuple[Tensor, Tensor]: + """Encodes target tensor with embedding and postion. + + Args: + trg (Tensor): Target tensor. + + Returns: + Tuple[Tensor, Tensor]: Encoded target tensor and target mask. + + """ + trg_mask = self._create_trg_mask(trg) + trg = self.character_embedding(trg.long()) + trg = self.position_encoding(trg) + return trg, trg_mask + + def forward(self, x: Tensor, trg: Tensor) -> Tensor: + """Forward pass with CNN transfomer.""" + src = self.preprocess_input(x) + trg, trg_mask = self.preprocess_target(trg) + out = self.transformer(src, trg, trg_mask=trg_mask) + logits = self.head(out) + return logits diff --git a/src/text_recognizer/networks/line_lstm_ctc.py b/src/text_recognizer/networks/crnn.py index 9009f94..3e605e2 100644 --- a/src/text_recognizer/networks/line_lstm_ctc.py +++ b/src/text_recognizer/networks/crnn.py @@ -10,15 +10,16 @@ import torch from torch import nn from torch import Tensor +from text_recognizer.networks.util import configure_backbone -class LineRecurrentNetwork(nn.Module): + +class ConvolutionalRecurrentNetwork(nn.Module): """Network that takes a image of a text line and predicts tokens that are in the image.""" def __init__( self, backbone: str, backbone_args: Dict = None, - flatten: bool = True, input_size: int = 128, hidden_size: int = 128, bidirectional: bool = False, @@ -26,6 +27,7 @@ class LineRecurrentNetwork(nn.Module): num_classes: int = 80, patch_size: Tuple[int, int] = (28, 28), stride: Tuple[int, int] = (1, 14), + recurrent_cell: str = "lstm", ) -> None: super().__init__() self.backbone_args = backbone_args or {} @@ -34,17 +36,19 @@ class LineRecurrentNetwork(nn.Module): self.sliding_window = self._configure_sliding_window() self.input_size = input_size self.hidden_size = hidden_size - self.backbone = self._configure_backbone(backbone) + self.backbone = configure_backbone(backbone, backbone_args) self.bidirectional = bidirectional - self.flatten = flatten - if self.flatten: - self.fc = nn.Linear( - in_features=self.input_size, out_features=self.hidden_size + if recurrent_cell.upper() in ["LSTM", "GRU"]: + recurrent_cell = getattr(nn, recurrent_cell) + else: + logger.warning( + f"Option {recurrent_cell} not valid, defaulting to LSTM cell." ) + recurrent_cell = nn.LSTM - self.rnn = nn.LSTM( - input_size=self.hidden_size, + self.rnn = recurrent_cell( + input_size=self.input_size, hidden_size=self.hidden_size, bidirectional=bidirectional, num_layers=num_layers, @@ -57,32 +61,6 @@ class LineRecurrentNetwork(nn.Module): nn.LogSoftmax(dim=2), ) - def _configure_backbone(self, backbone: str) -> Type[nn.Module]: - network_module = importlib.import_module("text_recognizer.networks") - backbone_ = getattr(network_module, backbone) - - if "pretrained" in self.backbone_args: - logger.info("Loading pretrained backbone.") - checkpoint_file = Path(__file__).resolve().parents[ - 2 - ] / self.backbone_args.pop("pretrained") - - # Loading state directory. - state_dict = torch.load(checkpoint_file) - network_args = state_dict["network_args"] - weights = state_dict["model_state"] - - # Initializes the network with trained weights. - backbone = backbone_(**network_args) - backbone.load_state_dict(weights) - if "freeze" in self.backbone_args and self.backbone_args["freeze"] is True: - for params in backbone.parameters(): - params.requires_grad = False - - return backbone - else: - return backbone_(**self.backbone_args) - def _configure_sliding_window(self) -> nn.Sequential: return nn.Sequential( nn.Unfold(kernel_size=self.patch_size, stride=self.stride), @@ -96,8 +74,8 @@ class LineRecurrentNetwork(nn.Module): def forward(self, x: Tensor) -> Tensor: """Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM.""" - if len(x.shape) == 3: - x = x.unsqueeze(0) + if len(x.shape) < 4: + x = x[(None,) * (4 - len(x.shape))] x = self.sliding_window(x) # Rearrange from a sequence of patches for feedforward network. @@ -106,11 +84,7 @@ class LineRecurrentNetwork(nn.Module): x = self.backbone(x) # Avgerage pooling. - x = ( - self.fc(reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)) - if self.flatten - else rearrange(x, "(b t) h -> t b h", b=b, t=t) - ) + x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t) # Sequence predictions. x, _ = self.rnn(x) diff --git a/src/text_recognizer/networks/densenet.py b/src/text_recognizer/networks/densenet.py new file mode 100644 index 0000000..d2aad60 --- /dev/null +++ b/src/text_recognizer/networks/densenet.py @@ -0,0 +1,225 @@ +"""Defines a Densely Connected Convolutional Networks in PyTorch. + +Sources: +https://arxiv.org/abs/1608.06993 +https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py + +""" +from typing import List, Optional, Union + +from einops.layers.torch import Rearrange +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.util import activation_function + + +class _DenseLayer(nn.Module): + """A dense layer with pre-batch norm -> activation function -> Conv-layer x 2.""" + + def __init__( + self, + in_channels: int, + growth_rate: int, + bn_size: int, + dropout_rate: float, + activation: str = "relu", + ) -> None: + super().__init__() + activation_fn = activation_function(activation) + self.dense_layer = [ + nn.BatchNorm2d(in_channels), + activation_fn, + nn.Conv2d( + in_channels=in_channels, + out_channels=bn_size * growth_rate, + kernel_size=1, + stride=1, + bias=False, + ), + nn.BatchNorm2d(bn_size * growth_rate), + activation_fn, + nn.Conv2d( + in_channels=bn_size * growth_rate, + out_channels=growth_rate, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + ] + if dropout_rate: + self.dense_layer.append(nn.Dropout(p=dropout_rate)) + + self.dense_layer = nn.Sequential(*self.dense_layer) + + def forward(self, x: Union[Tensor, List[Tensor]]) -> Tensor: + if isinstance(x, list): + x = torch.cat(x, 1) + return self.dense_layer(x) + + +class _DenseBlock(nn.Module): + def __init__( + self, + num_layers: int, + in_channels: int, + bn_size: int, + growth_rate: int, + dropout_rate: float, + activation: str = "relu", + ) -> None: + super().__init__() + self.dense_block = self._build_dense_blocks( + num_layers, in_channels, bn_size, growth_rate, dropout_rate, activation + ) + + def _build_dense_blocks( + self, + num_layers: int, + in_channels: int, + bn_size: int, + growth_rate: int, + dropout_rate: float, + activation: str = "relu", + ) -> nn.ModuleList: + dense_block = [] + for i in range(num_layers): + dense_block.append( + _DenseLayer( + in_channels=in_channels + i * growth_rate, + growth_rate=growth_rate, + bn_size=bn_size, + dropout_rate=dropout_rate, + activation=activation, + ) + ) + return nn.ModuleList(dense_block) + + def forward(self, x: Tensor) -> Tensor: + feature_maps = [x] + for layer in self.dense_block: + x = layer(feature_maps) + feature_maps.append(x) + return torch.cat(feature_maps, 1) + + +class _Transition(nn.Module): + def __init__( + self, in_channels: int, out_channels: int, activation: str = "relu", + ) -> None: + super().__init__() + activation_fn = activation_function(activation) + self.transition = nn.Sequential( + nn.BatchNorm2d(in_channels), + activation_fn, + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + bias=False, + ), + nn.AvgPool2d(kernel_size=2, stride=2), + ) + + def forward(self, x: Tensor) -> Tensor: + return self.transition(x) + + +class DenseNet(nn.Module): + """Implementation of Densenet, a network archtecture that concats previous layers for maximum infomation flow.""" + + def __init__( + self, + growth_rate: int = 32, + block_config: List[int] = (6, 12, 24, 16), + in_channels: int = 1, + base_channels: int = 64, + num_classes: int = 80, + bn_size: int = 4, + dropout_rate: float = 0, + classifier: bool = True, + activation: str = "relu", + ) -> None: + super().__init__() + self.densenet = self._configure_densenet( + in_channels, + base_channels, + num_classes, + growth_rate, + block_config, + bn_size, + dropout_rate, + classifier, + activation, + ) + + def _configure_densenet( + self, + in_channels: int, + base_channels: int, + num_classes: int, + growth_rate: int, + block_config: List[int], + bn_size: int, + dropout_rate: float, + classifier: bool, + activation: str, + ) -> nn.Sequential: + activation_fn = activation_function(activation) + densenet = [ + nn.Conv2d( + in_channels=in_channels, + out_channels=base_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + nn.BatchNorm2d(base_channels), + activation_fn, + ] + + num_features = base_channels + + for i, num_layers in enumerate(block_config): + densenet.append( + _DenseBlock( + num_layers=num_layers, + in_channels=num_features, + bn_size=bn_size, + growth_rate=growth_rate, + dropout_rate=dropout_rate, + activation=activation, + ) + ) + num_features = num_features + num_layers * growth_rate + if i != len(block_config) - 1: + densenet.append( + _Transition( + in_channels=num_features, + out_channels=num_features // 2, + activation=activation, + ) + ) + num_features = num_features // 2 + + densenet.append(activation_fn) + + if classifier: + densenet.append(nn.AdaptiveAvgPool2d((1, 1))) + densenet.append(Rearrange("b c h w -> b (c h w)")) + densenet.append( + nn.Linear(in_features=num_features, out_features=num_classes) + ) + + return nn.Sequential(*densenet) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass of Densenet.""" + # If batch dimenstion is missing, it needs to be added. + if len(x.shape) < 4: + x = x[(None,) * (4 - len(x.shape))] + return self.densenet(x) diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py index 53c575e..527e1a0 100644 --- a/src/text_recognizer/networks/lenet.py +++ b/src/text_recognizer/networks/lenet.py @@ -5,7 +5,7 @@ from einops.layers.torch import Rearrange import torch from torch import nn -from text_recognizer.networks.misc import activation_function +from text_recognizer.networks.util import activation_function class LeNet(nn.Module): @@ -63,6 +63,6 @@ class LeNet(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: """The feedforward pass.""" # If batch dimenstion is missing, it needs to be added. - if len(x.shape) == 3: - x = x.unsqueeze(0) + if len(x.shape) < 4: + x = x[(None,) * (4 - len(x.shape))] return self.layers(x) diff --git a/src/text_recognizer/networks/losses.py b/src/text_recognizer/networks/loss.py index 73e0641..ff843cf 100644 --- a/src/text_recognizer/networks/losses.py +++ b/src/text_recognizer/networks/loss.py @@ -4,6 +4,9 @@ from torch import nn from torch import Tensor +__all__ = ["EmbeddingLoss"] + + class EmbeddingLoss: """Metric loss for training encoders to produce information-rich latent embeddings.""" diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py index d66af28..1101912 100644 --- a/src/text_recognizer/networks/mlp.py +++ b/src/text_recognizer/networks/mlp.py @@ -5,7 +5,7 @@ from einops.layers.torch import Rearrange import torch from torch import nn -from text_recognizer.networks.misc import activation_function +from text_recognizer.networks.util import activation_function class MLP(nn.Module): @@ -63,8 +63,8 @@ class MLP(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: """The feedforward pass.""" # If batch dimenstion is missing, it needs to be added. - if len(x.shape) == 3: - x = x.unsqueeze(0) + if len(x.shape) < 4: + x = x[(None,) * (4 - len(x.shape))] return self.layers(x) @property diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py index 046600d..6405192 100644 --- a/src/text_recognizer/networks/residual_network.py +++ b/src/text_recognizer/networks/residual_network.py @@ -7,8 +7,8 @@ import torch from torch import nn from torch import Tensor -from text_recognizer.networks.misc import activation_function from text_recognizer.networks.stn import SpatialTransformerNetwork +from text_recognizer.networks.util import activation_function class Conv2dAuto(nn.Conv2d): @@ -225,8 +225,8 @@ class ResidualNetworkEncoder(nn.Module): in_channels=in_channels, out_channels=self.block_sizes[0], kernel_size=3, - stride=2, - padding=3, + stride=1, + padding=1, bias=False, ), nn.BatchNorm2d(self.block_sizes[0]), diff --git a/src/text_recognizer/networks/sparse_mlp.py b/src/text_recognizer/networks/sparse_mlp.py new file mode 100644 index 0000000..53cf166 --- /dev/null +++ b/src/text_recognizer/networks/sparse_mlp.py @@ -0,0 +1,78 @@ +"""Defines the Sparse MLP network.""" +from typing import Callable, Dict, List, Optional, Union +import warnings + +from einops.layers.torch import Rearrange +from pytorch_block_sparse import BlockSparseLinear +import torch +from torch import nn + +from text_recognizer.networks.util import activation_function + +warnings.filterwarnings("ignore", category=DeprecationWarning) + + +class SparseMLP(nn.Module): + """Sparse multi layered perceptron network.""" + + def __init__( + self, + input_size: int = 784, + num_classes: int = 10, + hidden_size: Union[int, List] = 128, + num_layers: int = 3, + density: float = 0.1, + activation_fn: str = "relu", + ) -> None: + """Initialization of the MLP network. + + Args: + input_size (int): The input shape of the network. Defaults to 784. + num_classes (int): Number of classes in the dataset. Defaults to 10. + hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128. + num_layers (int): The number of hidden layers. Defaults to 3. + density (float): The density of activation at each layer. Default to 0.1. + activation_fn (str): Name of the activation function in the hidden layers. Defaults to + relu. + + """ + super().__init__() + + activation_fn = activation_function(activation_fn) + + if isinstance(hidden_size, int): + hidden_size = [hidden_size] * num_layers + + self.layers = [ + Rearrange("b c h w -> b (c h w)"), + nn.Linear(in_features=input_size, out_features=hidden_size[0]), + activation_fn, + ] + + for i in range(num_layers - 1): + self.layers += [ + BlockSparseLinear( + in_features=hidden_size[i], + out_features=hidden_size[i + 1], + density=density, + ), + activation_fn, + ] + + self.layers.append( + nn.Linear(in_features=hidden_size[-1], out_features=num_classes) + ) + + self.layers = nn.Sequential(*self.layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """The feedforward pass.""" + # If batch dimenstion is missing, it needs to be added. + if len(x.shape) < 4: + x = x[(None,) * (4 - len(x.shape))] + return self.layers(x) + + @property + def __name__(self) -> str: + """Returns the name of the network.""" + return "mlp" diff --git a/src/text_recognizer/networks/transformer.py b/src/text_recognizer/networks/transformer.py deleted file mode 100644 index c091ba0..0000000 --- a/src/text_recognizer/networks/transformer.py +++ /dev/null @@ -1,5 +0,0 @@ -"""TBC.""" -from typing import Dict - -import torch -from torch import Tensor diff --git a/src/text_recognizer/networks/transformer/__init__.py b/src/text_recognizer/networks/transformer/__init__.py new file mode 100644 index 0000000..020a917 --- /dev/null +++ b/src/text_recognizer/networks/transformer/__init__.py @@ -0,0 +1,3 @@ +"""Transformer modules.""" +from .positional_encoding import PositionalEncoding +from .transformer import Decoder, Encoder, Transformer diff --git a/src/text_recognizer/networks/transformer/attention.py b/src/text_recognizer/networks/transformer/attention.py new file mode 100644 index 0000000..cce1ecc --- /dev/null +++ b/src/text_recognizer/networks/transformer/attention.py @@ -0,0 +1,93 @@ +"""Implementes the attention module for the transformer.""" +from typing import Optional, Tuple + +from einops import rearrange +import numpy as np +import torch +from torch import nn +from torch import Tensor + + +class MultiHeadAttention(nn.Module): + """Implementation of multihead attention.""" + + def __init__( + self, hidden_dim: int, num_heads: int = 8, dropout_rate: float = 0.0 + ) -> None: + super().__init__() + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.fc_q = nn.Linear( + in_features=hidden_dim, out_features=hidden_dim, bias=False + ) + self.fc_k = nn.Linear( + in_features=hidden_dim, out_features=hidden_dim, bias=False + ) + self.fc_v = nn.Linear( + in_features=hidden_dim, out_features=hidden_dim, bias=False + ) + self.fc_out = nn.Linear(in_features=hidden_dim, out_features=hidden_dim) + + self._init_weights() + + self.dropout = nn.Dropout(p=dropout_rate) + + def _init_weights(self) -> None: + nn.init.normal_( + self.fc_q.weight, + mean=0, + std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)), + ) + nn.init.normal_( + self.fc_k.weight, + mean=0, + std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)), + ) + nn.init.normal_( + self.fc_v.weight, + mean=0, + std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)), + ) + nn.init.xavier_normal_(self.fc_out.weight) + + def scaled_dot_product_attention( + self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None + ) -> Tensor: + """Calculates the scaled dot product attention.""" + + # Compute the energy. + energy = torch.einsum("bhlk,bhtk->bhlt", [query, key]) / np.sqrt( + query.shape[-1] + ) + + # If we have a mask for padding some inputs. + if mask is not None: + energy = energy.masked_fill(mask == 0, -np.inf) + + # Compute the attention from the energy. + attention = torch.softmax(energy, dim=3) + + out = torch.einsum("bhlt,bhtv->bhlv", [attention, value]) + out = rearrange(out, "b head l v -> b l (head v)") + return out, attention + + def forward( + self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None + ) -> Tuple[Tensor, Tensor]: + """Forward pass for computing the multihead attention.""" + # Get the query, key, and value tensor. + query = rearrange( + self.fc_q(query), "b l (head k) -> b head l k", head=self.num_heads + ) + key = rearrange( + self.fc_k(key), "b t (head k) -> b head t k", head=self.num_heads + ) + value = rearrange( + self.fc_v(value), "b t (head v) -> b head t v", head=self.num_heads + ) + + out, attention = self.scaled_dot_product_attention(query, key, value, mask) + + out = self.fc_out(out) + out = self.dropout(out) + return out, attention diff --git a/src/text_recognizer/networks/transformer/positional_encoding.py b/src/text_recognizer/networks/transformer/positional_encoding.py new file mode 100644 index 0000000..a47141b --- /dev/null +++ b/src/text_recognizer/networks/transformer/positional_encoding.py @@ -0,0 +1,31 @@ +"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence.""" +import numpy as np +import torch +from torch import nn +from torch import Tensor + + +class PositionalEncoding(nn.Module): + """Encodes a sense of distance or time for transformer networks.""" + + def __init__( + self, hidden_dim: int, dropout_rate: float, max_len: int = 1000 + ) -> None: + super().__init__() + self.dropout = nn.Dropout(p=dropout_rate) + + pe = torch.zeros(max_len, hidden_dim) + position = torch.arange(0, max_len).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim) + ) + + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.register_buffer("pe", pe) + + def forward(self, x: Tensor) -> Tensor: + """Encodes the tensor with a postional embedding.""" + x = x + self.pe[:, : x.shape[1]] + return self.dropout(x) diff --git a/src/text_recognizer/networks/transformer/sparse_transformer.py b/src/text_recognizer/networks/transformer/sparse_transformer.py new file mode 100644 index 0000000..8c391c8 --- /dev/null +++ b/src/text_recognizer/networks/transformer/sparse_transformer.py @@ -0,0 +1 @@ +"""Encoder and Decoder modules using spares activations.""" diff --git a/src/text_recognizer/networks/transformer/transformer.py b/src/text_recognizer/networks/transformer/transformer.py new file mode 100644 index 0000000..1c9c7dd --- /dev/null +++ b/src/text_recognizer/networks/transformer/transformer.py @@ -0,0 +1,241 @@ +"""Transfomer module.""" +import copy +from typing import Dict, Optional, Type, Union + +import numpy as np +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.transformer.attention import MultiHeadAttention +from text_recognizer.networks.util import activation_function + + +def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList: + return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)]) + + +class _IntraLayerConnection(nn.Module): + """Preforms the residual connection inside the transfomer blocks and applies layernorm.""" + + def __init__(self, dropout_rate: float, hidden_dim: int) -> None: + super().__init__() + self.norm = nn.LayerNorm(normalized_shape=hidden_dim) + self.dropout = nn.Dropout(p=dropout_rate) + + def forward(self, src: Tensor, residual: Tensor) -> Tensor: + return self.norm(self.dropout(src) + residual) + + +class _ConvolutionalLayer(nn.Module): + def __init__( + self, + hidden_dim: int, + expansion_dim: int, + dropout_rate: float, + activation: str = "relu", + ) -> None: + super().__init__() + self.layer = nn.Sequential( + nn.Linear(in_features=hidden_dim, out_features=expansion_dim), + activation_function(activation), + nn.Dropout(p=dropout_rate), + nn.Linear(in_features=expansion_dim, out_features=hidden_dim), + ) + + def forward(self, x: Tensor) -> Tensor: + return self.layer(x) + + +class EncoderLayer(nn.Module): + """Transfomer encoding layer.""" + + def __init__( + self, + hidden_dim: int, + num_heads: int, + expansion_dim: int, + dropout_rate: float, + activation: str = "relu", + ) -> None: + super().__init__() + self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate) + self.cnn = _ConvolutionalLayer( + hidden_dim, expansion_dim, dropout_rate, activation + ) + self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim) + self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim) + + def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor: + """Forward pass through the encoder.""" + # First block. + # Multi head attention. + out, _ = self.self_attention(src, src, src, mask) + + # Add & norm. + out = self.block1(out, src) + + # Second block. + # Apply 1D-convolution. + cnn_out = self.cnn(out) + + # Add & norm. + out = self.block2(cnn_out, out) + + return out + + +class Encoder(nn.Module): + """Transfomer encoder module.""" + + def __init__( + self, + num_layers: int, + encoder_layer: Type[nn.Module], + norm: Optional[Type[nn.Module]] = None, + ) -> None: + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.norm = norm + + def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor: + """Forward pass through all encoder layers.""" + for layer in self.layers: + src = layer(src, src_mask) + + if self.norm is not None: + src = self.norm(src) + + return src + + +class DecoderLayer(nn.Module): + """Transfomer decoder layer.""" + + def __init__( + self, + hidden_dim: int, + num_heads: int, + expansion_dim: int, + dropout_rate: float = 0.0, + activation: str = "relu", + ) -> None: + super().__init__() + self.hidden_dim = hidden_dim + self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate) + self.multihead_attention = MultiHeadAttention( + hidden_dim, num_heads, dropout_rate + ) + self.cnn = _ConvolutionalLayer( + hidden_dim, expansion_dim, dropout_rate, activation + ) + self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim) + self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim) + self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim) + + def forward( + self, + trg: Tensor, + memory: Tensor, + trg_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + ) -> Tensor: + """Forward pass of the layer.""" + out, _ = self.self_attention(trg, trg, trg, trg_mask) + trg = self.block1(out, trg) + + out, _ = self.multihead_attention(trg, memory, memory, memory_mask) + trg = self.block2(out, trg) + + out = self.cnn(trg) + out = self.block3(out, trg) + + return out + + +class Decoder(nn.Module): + """Transfomer decoder module.""" + + def __init__( + self, + decoder_layer: Type[nn.Module], + num_layers: int, + norm: Optional[Type[nn.Module]] = None, + ) -> None: + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + trg: Tensor, + memory: Tensor, + trg_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + ) -> Tensor: + """Forward pass through the decoder.""" + for layer in self.layers: + trg = layer(trg, memory, trg_mask, memory_mask) + + if self.norm is not None: + trg = self.norm(trg) + + return trg + + +class Transformer(nn.Module): + """Transformer network.""" + + def __init__( + self, + num_encoder_layers: int, + num_decoder_layers: int, + hidden_dim: int, + num_heads: int, + expansion_dim: int, + dropout_rate: float, + activation: str = "relu", + ) -> None: + super().__init__() + + # Configure encoder. + encoder_norm = nn.LayerNorm(hidden_dim) + encoder_layer = EncoderLayer( + hidden_dim, num_heads, expansion_dim, dropout_rate, activation + ) + self.encoder = Encoder(num_encoder_layers, encoder_layer, encoder_norm) + + # Configure decoder. + decoder_norm = nn.LayerNorm(hidden_dim) + decoder_layer = DecoderLayer( + hidden_dim, num_heads, expansion_dim, dropout_rate, activation + ) + self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm) + + self._reset_parameters() + + def _reset_parameters(self) -> None: + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward( + self, + src: Tensor, + trg: Tensor, + src_mask: Optional[Tensor] = None, + trg_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + ) -> Tensor: + """Forward pass through the transformer.""" + if src.shape[0] != trg.shape[0]: + raise RuntimeError("The batch size of the src and trg must be the same.") + if src.shape[2] != trg.shape[2]: + raise RuntimeError( + "The number of features for the src and trg must be the same." + ) + + memory = self.encoder(src, src_mask) + output = self.decoder(trg, memory, trg_mask, memory_mask) + return output diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/util.py index 1f853e9..0d08506 100644 --- a/src/text_recognizer/networks/misc.py +++ b/src/text_recognizer/networks/util.py @@ -1,7 +1,10 @@ """Miscellaneous neural network functionality.""" -from typing import Tuple, Type +import importlib +from pathlib import Path +from typing import Dict, Tuple, Type from einops import rearrange +from loguru import logger import torch from torch import nn @@ -43,3 +46,38 @@ def activation_function(activation: str) -> Type[nn.Module]: ] ) return activation_fns[activation.lower()] + + +def configure_backbone(backbone: str, backbone_args: Dict) -> Type[nn.Module]: + """Loads a backbone network.""" + network_module = importlib.import_module("text_recognizer.networks") + backbone_ = getattr(network_module, backbone) + + if "pretrained" in backbone_args: + logger.info("Loading pretrained backbone.") + checkpoint_file = Path(__file__).resolve().parents[2] / backbone_args.pop( + "pretrained" + ) + + # Loading state directory. + state_dict = torch.load(checkpoint_file) + network_args = state_dict["network_args"] + weights = state_dict["model_state"] + + # Initializes the network with trained weights. + backbone = backbone_(**network_args) + backbone.load_state_dict(weights) + if "freeze" in backbone_args and backbone_args["freeze"] is True: + for params in backbone.parameters(): + params.requires_grad = False + + else: + backbone_ = getattr(network_module, backbone) + backbone = backbone_(**backbone_args) + + if "remove_layers" in backbone_args and backbone_args["remove_layers"] is not None: + backbone = nn.Sequential( + *list(backbone.children())[0][: -backbone_args["remove_layers"]] + ) + + return backbone diff --git a/src/text_recognizer/networks/vision_transformer.py b/src/text_recognizer/networks/vision_transformer.py new file mode 100644 index 0000000..4d204d3 --- /dev/null +++ b/src/text_recognizer/networks/vision_transformer.py @@ -0,0 +1,158 @@ +"""VisionTransformer module. + +Splits each image into patches and feeds them to a transformer. + +""" + +from typing import Dict, Optional, Tuple, Type + +from einops import rearrange, reduce +from einops.layers.torch import Rearrange +from loguru import logger +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.transformer import PositionalEncoding, Transformer +from text_recognizer.networks.util import configure_backbone + + +class VisionTransformer(nn.Module): + """Linear projection+Transfomer for image to sequence prediction, sort of based on the ideas from ViT.""" + + def __init__( + self, + num_encoder_layers: int, + num_decoder_layers: int, + hidden_dim: int, + vocab_size: int, + num_heads: int, + max_len: int, + expansion_dim: int, + mlp_dim: int, + dropout_rate: float, + trg_pad_index: int, + patch_size: Tuple[int, int] = (28, 28), + stride: Tuple[int, int] = (1, 14), + activation: str = "gelu", + backbone: Optional[str] = None, + backbone_args: Optional[Dict] = None, + ) -> None: + super().__init__() + + self.patch_size = patch_size + self.stride = stride + self.trg_pad_index = trg_pad_index + self.slidning_window = self._configure_sliding_window() + self.character_embedding = nn.Embedding(vocab_size, hidden_dim) + self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate, max_len) + + self.use_backbone = False + if backbone is None: + self.linear_projection = nn.Linear( + self.patch_size[0] * self.patch_size[1], hidden_dim + ) + else: + self.backbone = configure_backbone(backbone, backbone_args) + self.use_backbone = True + + self.transformer = Transformer( + num_encoder_layers, + num_decoder_layers, + hidden_dim, + num_heads, + expansion_dim, + dropout_rate, + activation, + ) + + self.head = nn.Sequential( + nn.LayerNorm(hidden_dim), + nn.Linear(hidden_dim, mlp_dim), + nn.GELU(), + nn.Dropout(p=dropout_rate), + nn.Linear(mlp_dim, vocab_size), + ) + + def _configure_sliding_window(self) -> nn.Sequential: + return nn.Sequential( + nn.Unfold(kernel_size=self.patch_size, stride=self.stride), + Rearrange( + "b (c h w) t -> b t c h w", + h=self.patch_size[0], + w=self.patch_size[1], + c=1, + ), + ) + + def _create_trg_mask(self, trg: Tensor) -> Tensor: + # Move this outside the transformer. + trg_pad_mask = (trg != self.trg_pad_index)[:, None, None] + trg_len = trg.shape[1] + trg_sub_mask = torch.tril( + torch.ones((trg_len, trg_len), device=trg.device) + ).bool() + trg_mask = trg_pad_mask & trg_sub_mask + return trg_mask + + def encoder(self, src: Tensor) -> Tensor: + """Forward pass with the encoder of the transformer.""" + return self.transformer.encoder(src) + + def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor: + """Forward pass with the decoder of the transformer + classification head.""" + return self.head( + self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask) + ) + + def _backbone(self, x: Tensor) -> Tensor: + b, t = x.shape[:2] + if self.use_backbone: + x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t) + x = self.backbone(x) + x = rearrange(x, "(b t) h -> b t h", b=b, t=t) + else: + x = rearrange(x, "b t c h w -> b t (c h w)", b=b, t=t) + x = self.linear_projection(x) + return x + + def preprocess_input(self, src: Tensor) -> Tensor: + """Encodes src with a backbone network and a positional encoding. + + Args: + src (Tensor): Input tensor. + + Returns: + Tensor: A input src to the transformer. + + """ + # If batch dimenstion is missing, it needs to be added. + if len(src.shape) < 4: + src = src[(None,) * (4 - len(src.shape))] + src = self.slidning_window(src) # .squeeze(-2) + src = self._backbone(src) + src = self.position_encoding(src) + return src + + def preprocess_target(self, trg: Tensor) -> Tuple[Tensor, Tensor]: + """Encodes target tensor with embedding and postion. + + Args: + trg (Tensor): Target tensor. + + Returns: + Tuple[Tensor, Tensor]: Encoded target tensor and target mask. + + """ + trg_mask = self._create_trg_mask(trg) + trg = self.character_embedding(trg.long()) + trg = self.position_encoding(trg) + return trg, trg_mask + + def forward(self, x: Tensor, trg: Tensor) -> Tensor: + """Forward pass with vision transfomer.""" + src = self.preprocess_input(x) + trg, trg_mask = self.preprocess_target(trg) + out = self.transformer(src, trg, trg_mask=trg_mask) + logits = self.head(out) + return logits diff --git a/src/text_recognizer/networks/wide_resnet.py b/src/text_recognizer/networks/wide_resnet.py index 618f414..aa79c12 100644 --- a/src/text_recognizer/networks/wide_resnet.py +++ b/src/text_recognizer/networks/wide_resnet.py @@ -8,7 +8,7 @@ import torch from torch import nn from torch import Tensor -from text_recognizer.networks.misc import activation_function +from text_recognizer.networks.util import activation_function def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: @@ -206,8 +206,8 @@ class WideResidualNetwork(nn.Module): def forward(self, x: Tensor) -> Tensor: """Feedforward pass.""" - if len(x.shape) == 3: - x = x.unsqueeze(0) + if len(x.shape) < 4: + x = x[(None,) * int(4 - len(x.shape))] x = self.encoder(x) if self.decoder is not None: x = self.decoder(x) |