From 7e8e54e84c63171e748bbf09516fd517e6821ace Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sat, 20 Mar 2021 18:09:06 +0100 Subject: Inital commit for refactoring to lightning --- src/text_recognizer/networks/__init__.py | 43 --- src/text_recognizer/networks/beam.py | 83 ----- src/text_recognizer/networks/cnn.py | 101 ----- src/text_recognizer/networks/cnn_transformer.py | 158 -------- src/text_recognizer/networks/crnn.py | 110 ------ src/text_recognizer/networks/ctc.py | 58 --- src/text_recognizer/networks/densenet.py | 225 ----------- src/text_recognizer/networks/lenet.py | 68 ---- src/text_recognizer/networks/loss/__init__.py | 2 - src/text_recognizer/networks/loss/loss.py | 69 ---- src/text_recognizer/networks/metrics.py | 123 ------- src/text_recognizer/networks/mlp.py | 73 ---- src/text_recognizer/networks/residual_network.py | 310 ---------------- src/text_recognizer/networks/stn.py | 44 --- .../networks/transducer/__init__.py | 3 - .../networks/transducer/tds_conv.py | 208 ----------- src/text_recognizer/networks/transducer/test.py | 60 --- .../networks/transducer/transducer.py | 410 --------------------- .../networks/transformer/__init__.py | 3 - .../networks/transformer/attention.py | 93 ----- .../networks/transformer/positional_encoding.py | 32 -- .../networks/transformer/transformer.py | 264 ------------- src/text_recognizer/networks/unet.py | 255 ------------- src/text_recognizer/networks/util.py | 89 ----- src/text_recognizer/networks/vit.py | 150 -------- src/text_recognizer/networks/vq_transformer.py | 150 -------- src/text_recognizer/networks/vqvae/__init__.py | 5 - src/text_recognizer/networks/vqvae/decoder.py | 133 ------- src/text_recognizer/networks/vqvae/encoder.py | 147 -------- .../networks/vqvae/vector_quantizer.py | 119 ------ src/text_recognizer/networks/vqvae/vqvae.py | 74 ---- src/text_recognizer/networks/wide_resnet.py | 221 ----------- 32 files changed, 3883 deletions(-) delete mode 100644 src/text_recognizer/networks/__init__.py delete mode 100644 src/text_recognizer/networks/beam.py delete mode 100644 src/text_recognizer/networks/cnn.py delete mode 100644 src/text_recognizer/networks/cnn_transformer.py delete mode 100644 src/text_recognizer/networks/crnn.py delete mode 100644 src/text_recognizer/networks/ctc.py delete mode 100644 src/text_recognizer/networks/densenet.py delete mode 100644 src/text_recognizer/networks/lenet.py delete mode 100644 src/text_recognizer/networks/loss/__init__.py delete mode 100644 src/text_recognizer/networks/loss/loss.py delete mode 100644 src/text_recognizer/networks/metrics.py delete mode 100644 src/text_recognizer/networks/mlp.py delete mode 100644 src/text_recognizer/networks/residual_network.py delete mode 100644 src/text_recognizer/networks/stn.py delete mode 100644 src/text_recognizer/networks/transducer/__init__.py delete mode 100644 src/text_recognizer/networks/transducer/tds_conv.py delete mode 100644 src/text_recognizer/networks/transducer/test.py delete mode 100644 src/text_recognizer/networks/transducer/transducer.py delete mode 100644 src/text_recognizer/networks/transformer/__init__.py delete mode 100644 src/text_recognizer/networks/transformer/attention.py delete mode 100644 src/text_recognizer/networks/transformer/positional_encoding.py delete mode 100644 src/text_recognizer/networks/transformer/transformer.py delete mode 100644 src/text_recognizer/networks/unet.py delete mode 100644 src/text_recognizer/networks/util.py delete mode 100644 src/text_recognizer/networks/vit.py delete mode 100644 src/text_recognizer/networks/vq_transformer.py delete mode 100644 src/text_recognizer/networks/vqvae/__init__.py delete mode 100644 src/text_recognizer/networks/vqvae/decoder.py delete mode 100644 src/text_recognizer/networks/vqvae/encoder.py delete mode 100644 src/text_recognizer/networks/vqvae/vector_quantizer.py delete mode 100644 src/text_recognizer/networks/vqvae/vqvae.py delete mode 100644 src/text_recognizer/networks/wide_resnet.py (limited to 'src/text_recognizer/networks') diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py deleted file mode 100644 index 1521355..0000000 --- a/src/text_recognizer/networks/__init__.py +++ /dev/null @@ -1,43 +0,0 @@ -"""Network modules.""" -from .cnn import CNN -from .cnn_transformer import CNNTransformer -from .crnn import ConvolutionalRecurrentNetwork -from .ctc import greedy_decoder -from .densenet import DenseNet -from .lenet import LeNet -from .metrics import accuracy, cer, wer -from .mlp import MLP -from .residual_network import ResidualNetwork, ResidualNetworkEncoder -from .transducer import load_transducer_loss, TDS2d -from .transformer import Transformer -from .unet import UNet -from .util import sliding_window -from .vit import ViT -from .vq_transformer import VQTransformer -from .vqvae import VQVAE -from .wide_resnet import WideResidualNetwork - -__all__ = [ - "accuracy", - "cer", - "CNN", - "CNNTransformer", - "ConvolutionalRecurrentNetwork", - "DenseNet", - "FCN", - "greedy_decoder", - "MLP", - "LeNet", - "load_transducer_loss", - "ResidualNetwork", - "ResidualNetworkEncoder", - "sliding_window", - "UNet", - "TDS2d", - "Transformer", - "ViT", - "VQTransformer", - "VQVAE", - "wer", - "WideResidualNetwork", -] diff --git a/src/text_recognizer/networks/beam.py b/src/text_recognizer/networks/beam.py deleted file mode 100644 index dccccdb..0000000 --- a/src/text_recognizer/networks/beam.py +++ /dev/null @@ -1,83 +0,0 @@ -"""Implementation of beam search decoder for a sequence to sequence network. - -Stolen from: https://github.com/budzianowski/PyTorch-Beam-Search-Decoding/blob/master/decode_beam.py - -""" -# from typing import List -# from Queue import PriorityQueue - -# from loguru import logger -# import torch -# from torch import nn -# from torch import Tensor -# import torch.nn.functional as F - - -# class Node: -# def __init__( -# self, parent: Node, target_index: int, log_prob: Tensor, length: int -# ) -> None: -# self.parent = parent -# self.target_index = target_index -# self.log_prob = log_prob -# self.length = length -# self.reward = 0.0 - -# def eval(self, alpha: float = 1.0) -> Tensor: -# return self.log_prob / (self.length - 1 + 1e-6) + alpha * self.reward - - -# @torch.no_grad() -# def beam_decoder( -# network, mapper, device, memory: Tensor = None, max_len: int = 97, -# ) -> Tensor: -# beam_width = 10 -# topk = 1 # How many sentences to generate. - -# trg_indices = [mapper(mapper.init_token)] - -# end_nodes = [] - -# node = Node(None, trg_indices, 0, 1) -# nodes = PriorityQueue() - -# nodes.put((node.eval(), node)) -# q_size = 1 - -# # Beam search -# for _ in range(max_len): -# if q_size > 2000: -# logger.warning("Could not decoder input") -# break - -# # Fetch the best node. -# score, n = nodes.get() -# decoder_input = n.target_index - -# if n.target_index == mapper(mapper.eos_token) and n.parent is not None: -# end_nodes.append((score, n)) - -# # If we reached the maximum number of sentences required. -# if len(end_nodes) >= 1: -# break -# else: -# continue - -# # Forward pass with transformer. -# trg = torch.tensor(trg_indices, device=device)[None, :].long() -# trg = network.target_embedding(trg) -# logits = network.decoder(trg=trg, memory=memory, trg_mask=None) -# log_prob = F.log_softmax(logits, dim=2) - -# log_prob, indices = torch.topk(log_prob, beam_width) - -# for new_k in range(beam_width): -# # TODO: continue from here -# token_index = indices[0][new_k].view(1, -1) -# log_p = log_prob[0][new_k].item() - -# node = Node() - -# pass - -# pass diff --git a/src/text_recognizer/networks/cnn.py b/src/text_recognizer/networks/cnn.py deleted file mode 100644 index 1807bb9..0000000 --- a/src/text_recognizer/networks/cnn.py +++ /dev/null @@ -1,101 +0,0 @@ -"""Implementation of a simple backbone cnn network.""" -from typing import Callable, Dict, Optional, Tuple - -from einops.layers.torch import Rearrange -import torch -from torch import nn - -from text_recognizer.networks.util import activation_function - - -class CNN(nn.Module): - """LeNet network for character prediction.""" - - def __init__( - self, - channels: Tuple[int, ...] = (1, 32, 64, 128), - kernel_sizes: Tuple[int, ...] = (4, 4, 4), - strides: Tuple[int, ...] = (2, 2, 2), - max_pool_kernel: int = 2, - dropout_rate: float = 0.2, - activation: Optional[str] = "relu", - ) -> None: - """Initialization of the LeNet network. - - Args: - channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64). - kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2). - strides (Tuple[int, ...]): Stride length of the convolutional filter. Defaults to (2, 2, 2). - max_pool_kernel (int): 2D max pooling kernel. Defaults to 2. - dropout_rate (float): The dropout rate. Defaults to 0.2. - activation (Optional[str]): The name of non-linear activation function. Defaults to relu. - - Raises: - RuntimeError: if the number of hyperparameters does not match in length. - - """ - super().__init__() - - if len(channels) - 1 != len(kernel_sizes) and len(kernel_sizes) != len(strides): - raise RuntimeError("The number of the hyperparameters does not match.") - - self.cnn = self._build_network( - channels, kernel_sizes, strides, max_pool_kernel, dropout_rate, activation, - ) - - def _build_network( - self, - channels: Tuple[int, ...], - kernel_sizes: Tuple[int, ...], - strides: Tuple[int, ...], - max_pool_kernel: int, - dropout_rate: float, - activation: str, - ) -> nn.Sequential: - # Load activation function. - activation_fn = activation_function(activation) - - channels = list(channels) - in_channels = channels.pop(0) - configuration = zip(channels, kernel_sizes, strides) - - modules = nn.ModuleList([]) - - for i, (out_channels, kernel_size, stride) in enumerate(configuration): - # Add max pool to reduce output size. - if i == len(channels) // 2: - modules.append(nn.MaxPool2d(max_pool_kernel)) - if i == 0: - modules.append( - nn.Conv2d( - in_channels, out_channels, kernel_size, stride=stride, padding=1 - ) - ) - else: - modules.append( - nn.Sequential( - activation_fn, - nn.BatchNorm2d(in_channels), - nn.Conv2d( - in_channels, - out_channels, - kernel_size, - stride=stride, - padding=1, - ), - ) - ) - - if dropout_rate: - modules.append(nn.Dropout2d(p=dropout_rate)) - - in_channels = out_channels - - return nn.Sequential(*modules) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """The feedforward pass.""" - # If batch dimenstion is missing, it needs to be added. - if len(x.shape) < 4: - x = x[(None,) * (4 - len(x.shape))] - return self.cnn(x) diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py deleted file mode 100644 index a2d7926..0000000 --- a/src/text_recognizer/networks/cnn_transformer.py +++ /dev/null @@ -1,158 +0,0 @@ -"""A CNN-Transformer for image to text recognition.""" -from typing import Dict, Optional, Tuple - -from einops import rearrange, repeat -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.transformer import PositionalEncoding, Transformer -from text_recognizer.networks.util import activation_function -from text_recognizer.networks.util import configure_backbone - - -class CNNTransformer(nn.Module): - """CNN+Transfomer for image to sequence prediction.""" - - def __init__( - self, - num_encoder_layers: int, - num_decoder_layers: int, - hidden_dim: int, - vocab_size: int, - num_heads: int, - adaptive_pool_dim: Tuple, - expansion_dim: int, - dropout_rate: float, - trg_pad_index: int, - max_len: int, - backbone: str, - backbone_args: Optional[Dict] = None, - activation: str = "gelu", - pool_kernel: Optional[Tuple[int, int]] = None, - ) -> None: - super().__init__() - self.trg_pad_index = trg_pad_index - self.vocab_size = vocab_size - self.backbone = configure_backbone(backbone, backbone_args) - - if pool_kernel is not None: - self.max_pool = nn.MaxPool2d(pool_kernel, stride=2) - else: - self.max_pool = None - - self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim) - - self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim)) - self.pos_dropout = nn.Dropout(p=dropout_rate) - self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate) - - nn.init.normal_(self.character_embedding.weight, std=0.02) - - self.adaptive_pool = ( - nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None - ) - - self.transformer = Transformer( - num_encoder_layers, - num_decoder_layers, - hidden_dim, - num_heads, - expansion_dim, - dropout_rate, - activation, - ) - - self.head = nn.Sequential( - # nn.Linear(hidden_dim, hidden_dim * 2), - # activation_function(activation), - nn.Linear(hidden_dim, vocab_size), - ) - - def _create_trg_mask(self, trg: Tensor) -> Tensor: - # Move this outside the transformer. - trg_pad_mask = (trg != self.trg_pad_index)[:, None, None] - trg_len = trg.shape[1] - trg_sub_mask = torch.tril( - torch.ones((trg_len, trg_len), device=trg.device) - ).bool() - trg_mask = trg_pad_mask & trg_sub_mask - return trg_mask - - def encoder(self, src: Tensor) -> Tensor: - """Forward pass with the encoder of the transformer.""" - return self.transformer.encoder(src) - - def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor: - """Forward pass with the decoder of the transformer + classification head.""" - return self.head( - self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask) - ) - - def extract_image_features(self, src: Tensor) -> Tensor: - """Extracts image features with a backbone neural network. - - It seem like the winning idea was to swap channels and width dimension and collapse - the height dimension. The transformer is learning like a baby with this implementation!!! :D - Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D - - Args: - src (Tensor): Input tensor. - - Returns: - Tensor: A input src to the transformer. - - """ - # If batch dimension is missing, it needs to be added. - if len(src.shape) < 4: - src = src[(None,) * (4 - len(src.shape))] - - src = self.backbone(src) - - if self.max_pool is not None: - src = self.max_pool(src) - - if self.adaptive_pool is not None and len(src.shape) == 4: - src = rearrange(src, "b c h w -> b w c h") - src = self.adaptive_pool(src) - src = src.squeeze(3) - elif len(src.shape) == 4: - src = rearrange(src, "b c h w -> b (h w) c") - - b, t, _ = src.shape - - src += self.src_position_embedding[:, :t] - src = self.pos_dropout(src) - - return src - - def target_embedding(self, trg: Tensor) -> Tuple[Tensor, Tensor]: - """Encodes target tensor with embedding and postion. - - Args: - trg (Tensor): Target tensor. - - Returns: - Tuple[Tensor, Tensor]: Encoded target tensor and target mask. - - """ - trg = self.character_embedding(trg.long()) - trg = self.trg_position_encoding(trg) - return trg - - def decode_image_features( - self, image_features: Tensor, trg: Optional[Tensor] = None - ) -> Tensor: - """Takes images features from the backbone and decodes them with the transformer.""" - trg_mask = self._create_trg_mask(trg) - trg = self.target_embedding(trg) - out = self.transformer(image_features, trg, trg_mask=trg_mask) - - logits = self.head(out) - return logits - - def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor: - """Forward pass with CNN transfomer.""" - image_features = self.extract_image_features(x) - logits = self.decode_image_features(image_features, trg) - return logits diff --git a/src/text_recognizer/networks/crnn.py b/src/text_recognizer/networks/crnn.py deleted file mode 100644 index 778e232..0000000 --- a/src/text_recognizer/networks/crnn.py +++ /dev/null @@ -1,110 +0,0 @@ -"""CRNN for handwritten text recognition.""" -from typing import Dict, Tuple - -from einops import rearrange, reduce -from einops.layers.torch import Rearrange -from loguru import logger -from torch import nn -from torch import Tensor - -from text_recognizer.networks.util import configure_backbone - - -class ConvolutionalRecurrentNetwork(nn.Module): - """Network that takes a image of a text line and predicts tokens that are in the image.""" - - def __init__( - self, - backbone: str, - backbone_args: Dict = None, - input_size: int = 128, - hidden_size: int = 128, - bidirectional: bool = False, - num_layers: int = 1, - num_classes: int = 80, - patch_size: Tuple[int, int] = (28, 28), - stride: Tuple[int, int] = (1, 14), - recurrent_cell: str = "lstm", - avg_pool: bool = False, - use_sliding_window: bool = True, - ) -> None: - super().__init__() - self.backbone_args = backbone_args or {} - self.patch_size = patch_size - self.stride = stride - self.sliding_window = ( - self._configure_sliding_window() if use_sliding_window else None - ) - self.input_size = input_size - self.hidden_size = hidden_size - self.backbone = configure_backbone(backbone, backbone_args) - self.bidirectional = bidirectional - self.avg_pool = avg_pool - - if recurrent_cell.upper() in ["LSTM", "GRU"]: - recurrent_cell = getattr(nn, recurrent_cell) - else: - logger.warning( - f"Option {recurrent_cell} not valid, defaulting to LSTM cell." - ) - recurrent_cell = nn.LSTM - - self.rnn = recurrent_cell( - input_size=self.input_size, - hidden_size=self.hidden_size, - bidirectional=bidirectional, - num_layers=num_layers, - ) - - decoder_size = self.hidden_size * 2 if self.bidirectional else self.hidden_size - - self.decoder = nn.Sequential( - nn.Linear(in_features=decoder_size, out_features=num_classes), - nn.LogSoftmax(dim=2), - ) - - def _configure_sliding_window(self) -> nn.Sequential: - return nn.Sequential( - nn.Unfold(kernel_size=self.patch_size, stride=self.stride), - Rearrange( - "b (c h w) t -> b t c h w", - h=self.patch_size[0], - w=self.patch_size[1], - c=1, - ), - ) - - def forward(self, x: Tensor) -> Tensor: - """Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM.""" - if len(x.shape) < 4: - x = x[(None,) * (4 - len(x.shape))] - - if self.sliding_window is not None: - # Create image patches with a sliding window kernel. - x = self.sliding_window(x) - - # Rearrange from a sequence of patches for feedforward network. - b, t = x.shape[:2] - x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t) - - x = self.backbone(x) - - # Average pooling. - if self.avg_pool: - x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t) - else: - x = rearrange(x, "(b t) h -> t b h", b=b, t=t) - else: - # Encode the entire image with a CNN, and use the channels as temporal dimension. - x = self.backbone(x) - x = rearrange(x, "b c h w -> b w c h") - if self.adaptive_pool is not None: - x = self.adaptive_pool(x) - x = x.squeeze(3) - - # Sequence predictions. - x, _ = self.rnn(x) - - # Sequence to classification layer. - x = self.decoder(x) - return x diff --git a/src/text_recognizer/networks/ctc.py b/src/text_recognizer/networks/ctc.py deleted file mode 100644 index af9b700..0000000 --- a/src/text_recognizer/networks/ctc.py +++ /dev/null @@ -1,58 +0,0 @@ -"""Decodes the CTC output.""" -from typing import Callable, List, Optional, Tuple - -from einops import rearrange -import torch -from torch import Tensor - -from text_recognizer.datasets.util import EmnistMapper - - -def greedy_decoder( - predictions: Tensor, - targets: Optional[Tensor] = None, - target_lengths: Optional[Tensor] = None, - character_mapper: Optional[Callable] = None, - blank_label: int = 79, - collapse_repeated: bool = True, -) -> Tuple[List[str], List[str]]: - """Greedy CTC decoder. - - Args: - predictions (Tensor): Tenor of network predictions, shape [time, batch, classes]. - targets (Optional[Tensor]): Target tensor, shape is [batch, targets]. Defaults to None. - target_lengths (Optional[Tensor]): Length of each target tensor. Defaults to None. - character_mapper (Optional[Callable]): A emnist/character mapper for mapping integers to characters. Defaults - to None. - blank_label (int): The blank character to be ignored. Defaults to 80. - collapse_repeated (bool): Collapase consecutive predictions of the same character. Defaults to True. - - Returns: - Tuple[List[str], List[str]]: Tuple of decoded predictions and decoded targets. - - """ - - if character_mapper is None: - character_mapper = EmnistMapper(pad_token="_") # noqa: S106 - - predictions = rearrange(torch.argmax(predictions, dim=2), "t b -> b t") - decoded_predictions = [] - decoded_targets = [] - for i, prediction in enumerate(predictions): - decoded_prediction = [] - decoded_target = [] - if targets is not None and target_lengths is not None: - for target_index in targets[i][: target_lengths[i]]: - if target_index == blank_label: - continue - decoded_target.append(character_mapper(int(target_index))) - decoded_targets.append(decoded_target) - for j, index in enumerate(prediction): - if index != blank_label: - if collapse_repeated and j != 0 and index == prediction[j - 1]: - continue - decoded_prediction.append(index.item()) - decoded_predictions.append( - [character_mapper(int(pred_index)) for pred_index in decoded_prediction] - ) - return decoded_predictions, decoded_targets diff --git a/src/text_recognizer/networks/densenet.py b/src/text_recognizer/networks/densenet.py deleted file mode 100644 index 7dc58d9..0000000 --- a/src/text_recognizer/networks/densenet.py +++ /dev/null @@ -1,225 +0,0 @@ -"""Defines a Densely Connected Convolutional Networks in PyTorch. - -Sources: -https://arxiv.org/abs/1608.06993 -https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py - -""" -from typing import List, Optional, Union - -from einops.layers.torch import Rearrange -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.util import activation_function - - -class _DenseLayer(nn.Module): - """A dense layer with pre-batch norm -> activation function -> Conv-layer x 2.""" - - def __init__( - self, - in_channels: int, - growth_rate: int, - bn_size: int, - dropout_rate: float, - activation: str = "relu", - ) -> None: - super().__init__() - activation_fn = activation_function(activation) - self.dense_layer = [ - nn.BatchNorm2d(in_channels), - activation_fn, - nn.Conv2d( - in_channels=in_channels, - out_channels=bn_size * growth_rate, - kernel_size=1, - stride=1, - bias=False, - ), - nn.BatchNorm2d(bn_size * growth_rate), - activation_fn, - nn.Conv2d( - in_channels=bn_size * growth_rate, - out_channels=growth_rate, - kernel_size=3, - stride=1, - padding=1, - bias=False, - ), - ] - if dropout_rate: - self.dense_layer.append(nn.Dropout(p=dropout_rate)) - - self.dense_layer = nn.Sequential(*self.dense_layer) - - def forward(self, x: Union[Tensor, List[Tensor]]) -> Tensor: - if isinstance(x, list): - x = torch.cat(x, 1) - return self.dense_layer(x) - - -class _DenseBlock(nn.Module): - def __init__( - self, - num_layers: int, - in_channels: int, - bn_size: int, - growth_rate: int, - dropout_rate: float, - activation: str = "relu", - ) -> None: - super().__init__() - self.dense_block = self._build_dense_blocks( - num_layers, in_channels, bn_size, growth_rate, dropout_rate, activation, - ) - - def _build_dense_blocks( - self, - num_layers: int, - in_channels: int, - bn_size: int, - growth_rate: int, - dropout_rate: float, - activation: str = "relu", - ) -> nn.ModuleList: - dense_block = [] - for i in range(num_layers): - dense_block.append( - _DenseLayer( - in_channels=in_channels + i * growth_rate, - growth_rate=growth_rate, - bn_size=bn_size, - dropout_rate=dropout_rate, - activation=activation, - ) - ) - return nn.ModuleList(dense_block) - - def forward(self, x: Tensor) -> Tensor: - feature_maps = [x] - for layer in self.dense_block: - x = layer(feature_maps) - feature_maps.append(x) - return torch.cat(feature_maps, 1) - - -class _Transition(nn.Module): - def __init__( - self, in_channels: int, out_channels: int, activation: str = "relu", - ) -> None: - super().__init__() - activation_fn = activation_function(activation) - self.transition = nn.Sequential( - nn.BatchNorm2d(in_channels), - activation_fn, - nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - stride=1, - bias=False, - ), - nn.AvgPool2d(kernel_size=2, stride=2), - ) - - def forward(self, x: Tensor) -> Tensor: - return self.transition(x) - - -class DenseNet(nn.Module): - """Implementation of Densenet, a network archtecture that concats previous layers for maximum infomation flow.""" - - def __init__( - self, - growth_rate: int = 32, - block_config: List[int] = (6, 12, 24, 16), - in_channels: int = 1, - base_channels: int = 64, - num_classes: int = 80, - bn_size: int = 4, - dropout_rate: float = 0, - classifier: bool = True, - activation: str = "relu", - ) -> None: - super().__init__() - self.densenet = self._configure_densenet( - in_channels, - base_channels, - num_classes, - growth_rate, - block_config, - bn_size, - dropout_rate, - classifier, - activation, - ) - - def _configure_densenet( - self, - in_channels: int, - base_channels: int, - num_classes: int, - growth_rate: int, - block_config: List[int], - bn_size: int, - dropout_rate: float, - classifier: bool, - activation: str, - ) -> nn.Sequential: - activation_fn = activation_function(activation) - densenet = [ - nn.Conv2d( - in_channels=in_channels, - out_channels=base_channels, - kernel_size=3, - stride=1, - padding=1, - bias=False, - ), - nn.BatchNorm2d(base_channels), - activation_fn, - ] - - num_features = base_channels - - for i, num_layers in enumerate(block_config): - densenet.append( - _DenseBlock( - num_layers=num_layers, - in_channels=num_features, - bn_size=bn_size, - growth_rate=growth_rate, - dropout_rate=dropout_rate, - activation=activation, - ) - ) - num_features = num_features + num_layers * growth_rate - if i != len(block_config) - 1: - densenet.append( - _Transition( - in_channels=num_features, - out_channels=num_features // 2, - activation=activation, - ) - ) - num_features = num_features // 2 - - densenet.append(activation_fn) - - if classifier: - densenet.append(nn.AdaptiveAvgPool2d((1, 1))) - densenet.append(Rearrange("b c h w -> b (c h w)")) - densenet.append( - nn.Linear(in_features=num_features, out_features=num_classes) - ) - - return nn.Sequential(*densenet) - - def forward(self, x: Tensor) -> Tensor: - """Forward pass of Densenet.""" - # If batch dimenstion is missing, it will be added. - if len(x.shape) < 4: - x = x[(None,) * (4 - len(x.shape))] - return self.densenet(x) diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py deleted file mode 100644 index 527e1a0..0000000 --- a/src/text_recognizer/networks/lenet.py +++ /dev/null @@ -1,68 +0,0 @@ -"""Implementation of the LeNet network.""" -from typing import Callable, Dict, Optional, Tuple - -from einops.layers.torch import Rearrange -import torch -from torch import nn - -from text_recognizer.networks.util import activation_function - - -class LeNet(nn.Module): - """LeNet network for character prediction.""" - - def __init__( - self, - channels: Tuple[int, ...] = (1, 32, 64), - kernel_sizes: Tuple[int, ...] = (3, 3, 2), - hidden_size: Tuple[int, ...] = (9216, 128), - dropout_rate: float = 0.2, - num_classes: int = 10, - activation_fn: Optional[str] = "relu", - ) -> None: - """Initialization of the LeNet network. - - Args: - channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64). - kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2). - hidden_size (Tuple[int, ...]): Size of the flattend output form the convolutional layers. - Defaults to (9216, 128). - dropout_rate (float): The dropout rate. Defaults to 0.2. - num_classes (int): Number of classes. Defaults to 10. - activation_fn (Optional[str]): The name of non-linear activation function. Defaults to relu. - - """ - super().__init__() - - activation_fn = activation_function(activation_fn) - - self.layers = [ - nn.Conv2d( - in_channels=channels[0], - out_channels=channels[1], - kernel_size=kernel_sizes[0], - ), - activation_fn, - nn.Conv2d( - in_channels=channels[1], - out_channels=channels[2], - kernel_size=kernel_sizes[1], - ), - activation_fn, - nn.MaxPool2d(kernel_sizes[2]), - nn.Dropout(p=dropout_rate), - Rearrange("b c h w -> b (c h w)"), - nn.Linear(in_features=hidden_size[0], out_features=hidden_size[1]), - activation_fn, - nn.Dropout(p=dropout_rate), - nn.Linear(in_features=hidden_size[1], out_features=num_classes), - ] - - self.layers = nn.Sequential(*self.layers) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """The feedforward pass.""" - # If batch dimenstion is missing, it needs to be added. - if len(x.shape) < 4: - x = x[(None,) * (4 - len(x.shape))] - return self.layers(x) diff --git a/src/text_recognizer/networks/loss/__init__.py b/src/text_recognizer/networks/loss/__init__.py deleted file mode 100644 index b489264..0000000 --- a/src/text_recognizer/networks/loss/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""Loss module.""" -from .loss import EmbeddingLoss, LabelSmoothingCrossEntropy diff --git a/src/text_recognizer/networks/loss/loss.py b/src/text_recognizer/networks/loss/loss.py deleted file mode 100644 index cf9fa0d..0000000 --- a/src/text_recognizer/networks/loss/loss.py +++ /dev/null @@ -1,69 +0,0 @@ -"""Implementations of custom loss functions.""" -from pytorch_metric_learning import distances, losses, miners, reducers -import torch -from torch import nn -from torch import Tensor -from torch.autograd import Variable -import torch.nn.functional as F - -__all__ = ["EmbeddingLoss", "LabelSmoothingCrossEntropy"] - - -class EmbeddingLoss: - """Metric loss for training encoders to produce information-rich latent embeddings.""" - - def __init__(self, margin: float = 0.2, type_of_triplets: str = "semihard") -> None: - self.distance = distances.CosineSimilarity() - self.reducer = reducers.ThresholdReducer(low=0) - self.loss_fn = losses.TripletMarginLoss( - margin=margin, distance=self.distance, reducer=self.reducer - ) - self.miner = miners.MultiSimilarityMiner(epsilon=margin, distance=self.distance) - - def __call__(self, embeddings: Tensor, labels: Tensor) -> Tensor: - """Computes the metric loss for the embeddings based on their labels. - - Args: - embeddings (Tensor): The laten vectors encoded by the network. - labels (Tensor): Labels of the embeddings. - - Returns: - Tensor: The metric loss for the embeddings. - - """ - hard_pairs = self.miner(embeddings, labels) - loss = self.loss_fn(embeddings, labels, hard_pairs) - return loss - - -class LabelSmoothingCrossEntropy(nn.Module): - """Label smoothing loss function.""" - - def __init__( - self, - classes: int, - smoothing: float = 0.0, - ignore_index: int = None, - dim: int = -1, - ) -> None: - super().__init__() - self.confidence = 1.0 - smoothing - self.smoothing = smoothing - self.ignore_index = ignore_index - self.cls = classes - self.dim = dim - - def forward(self, pred: Tensor, target: Tensor) -> Tensor: - """Calculates the loss.""" - pred = pred.log_softmax(dim=self.dim) - with torch.no_grad(): - # true_dist = pred.data.clone() - true_dist = torch.zeros_like(pred) - true_dist.fill_(self.smoothing / (self.cls - 1)) - true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) - if self.ignore_index is not None: - true_dist[:, self.ignore_index] = 0 - mask = torch.nonzero(target == self.ignore_index, as_tuple=False) - if mask.dim() > 0: - true_dist.index_fill_(0, mask.squeeze(), 0.0) - return torch.mean(torch.sum(-true_dist * pred, dim=self.dim)) diff --git a/src/text_recognizer/networks/metrics.py b/src/text_recognizer/networks/metrics.py deleted file mode 100644 index 2605731..0000000 --- a/src/text_recognizer/networks/metrics.py +++ /dev/null @@ -1,123 +0,0 @@ -"""Utility functions for models.""" -from typing import Optional - -from einops import rearrange -import Levenshtein as Lev -import torch -from torch import Tensor - -from text_recognizer.networks import greedy_decoder - - -def accuracy(outputs: Tensor, labels: Tensor, pad_index: int = 53) -> float: - """Computes the accuracy. - - Args: - outputs (Tensor): The output from the network. - labels (Tensor): Ground truth labels. - pad_index (int): Padding index. - - Returns: - float: The accuracy for the batch. - - """ - - _, predicted = torch.max(outputs, dim=-1) - - # Mask out the pad tokens - mask = labels != pad_index - - predicted *= mask - labels *= mask - - acc = (predicted == labels).sum().float() / labels.shape[0] - acc = acc.item() - return acc - - -def cer( - outputs: Tensor, - targets: Tensor, - batch_size: Optional[int] = None, - blank_label: Optional[int] = int, -) -> float: - """Computes the character error rate. - - Args: - outputs (Tensor): The output from the network. - targets (Tensor): Ground truth labels. - batch_size (Optional[int]): Batch size if target and output has been flattend. - blank_label (Optional[int]): The blank character to be ignored. Defaults to 79. - - Returns: - float: The cer for the batch. - - """ - if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None: - targets = rearrange(targets, "(b t) -> b t", b=batch_size) - outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size) - - target_lengths = torch.full( - size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long, - ) - decoded_predictions, decoded_targets = greedy_decoder( - outputs, targets, target_lengths, blank_label=blank_label, - ) - - lev_dist = 0 - - for prediction, target in zip(decoded_predictions, decoded_targets): - prediction = "".join(prediction) - target = "".join(target) - prediction, target = ( - prediction.replace(" ", ""), - target.replace(" ", ""), - ) - lev_dist += Lev.distance(prediction, target) - return lev_dist / len(decoded_predictions) - - -def wer( - outputs: Tensor, - targets: Tensor, - batch_size: Optional[int] = None, - blank_label: Optional[int] = int, -) -> float: - """Computes the Word error rate. - - Args: - outputs (Tensor): The output from the network. - targets (Tensor): Ground truth labels. - batch_size (optional[int]): Batch size if target and output has been flattend. - blank_label (Optional[int]): The blank character to be ignored. Defaults to 79. - - Returns: - float: The wer for the batch. - - """ - if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None: - targets = rearrange(targets, "(b t) -> b t", b=batch_size) - outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size) - - target_lengths = torch.full( - size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long, - ) - decoded_predictions, decoded_targets = greedy_decoder( - outputs, targets, target_lengths, blank_label=blank_label, - ) - - lev_dist = 0 - - for prediction, target in zip(decoded_predictions, decoded_targets): - prediction = "".join(prediction) - target = "".join(target) - - b = set(prediction.split() + target.split()) - word2char = dict(zip(b, range(len(b)))) - - w1 = [chr(word2char[w]) for w in prediction.split()] - w2 = [chr(word2char[w]) for w in target.split()] - - lev_dist += Lev.distance("".join(w1), "".join(w2)) - - return lev_dist / len(decoded_predictions) diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py deleted file mode 100644 index 1101912..0000000 --- a/src/text_recognizer/networks/mlp.py +++ /dev/null @@ -1,73 +0,0 @@ -"""Defines the MLP network.""" -from typing import Callable, Dict, List, Optional, Union - -from einops.layers.torch import Rearrange -import torch -from torch import nn - -from text_recognizer.networks.util import activation_function - - -class MLP(nn.Module): - """Multi layered perceptron network.""" - - def __init__( - self, - input_size: int = 784, - num_classes: int = 10, - hidden_size: Union[int, List] = 128, - num_layers: int = 3, - dropout_rate: float = 0.2, - activation_fn: str = "relu", - ) -> None: - """Initialization of the MLP network. - - Args: - input_size (int): The input shape of the network. Defaults to 784. - num_classes (int): Number of classes in the dataset. Defaults to 10. - hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128. - num_layers (int): The number of hidden layers. Defaults to 3. - dropout_rate (float): The dropout rate at each layer. Defaults to 0.2. - activation_fn (str): Name of the activation function in the hidden layers. Defaults to - relu. - - """ - super().__init__() - - activation_fn = activation_function(activation_fn) - - if isinstance(hidden_size, int): - hidden_size = [hidden_size] * num_layers - - self.layers = [ - Rearrange("b c h w -> b (c h w)"), - nn.Linear(in_features=input_size, out_features=hidden_size[0]), - activation_fn, - ] - - for i in range(num_layers - 1): - self.layers += [ - nn.Linear(in_features=hidden_size[i], out_features=hidden_size[i + 1]), - activation_fn, - ] - - if dropout_rate: - self.layers.append(nn.Dropout(p=dropout_rate)) - - self.layers.append( - nn.Linear(in_features=hidden_size[-1], out_features=num_classes) - ) - - self.layers = nn.Sequential(*self.layers) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """The feedforward pass.""" - # If batch dimenstion is missing, it needs to be added. - if len(x.shape) < 4: - x = x[(None,) * (4 - len(x.shape))] - return self.layers(x) - - @property - def __name__(self) -> str: - """Returns the name of the network.""" - return "mlp" diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py deleted file mode 100644 index c33f419..0000000 --- a/src/text_recognizer/networks/residual_network.py +++ /dev/null @@ -1,310 +0,0 @@ -"""Residual CNN.""" -from functools import partial -from typing import Callable, Dict, List, Optional, Type, Union - -from einops.layers.torch import Rearrange, Reduce -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.util import activation_function - - -class Conv2dAuto(nn.Conv2d): - """Convolution with auto padding based on kernel size.""" - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2) - - -def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential: - """3x3 convolution with batch norm.""" - conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,) - return nn.Sequential( - conv3x3(in_channels, out_channels, *args, **kwargs), - nn.BatchNorm2d(out_channels), - ) - - -class IdentityBlock(nn.Module): - """Residual with identity block.""" - - def __init__( - self, in_channels: int, out_channels: int, activation: str = "relu" - ) -> None: - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.blocks = nn.Identity() - self.activation_fn = activation_function(activation) - self.shortcut = nn.Identity() - - def forward(self, x: Tensor) -> Tensor: - """Forward pass.""" - residual = x - if self.apply_shortcut: - residual = self.shortcut(x) - x = self.blocks(x) - x += residual - x = self.activation_fn(x) - return x - - @property - def apply_shortcut(self) -> bool: - """Check if shortcut should be applied.""" - return self.in_channels != self.out_channels - - -class ResidualBlock(IdentityBlock): - """Residual with nonlinear shortcut.""" - - def __init__( - self, - in_channels: int, - out_channels: int, - expansion: int = 1, - downsampling: int = 1, - *args, - **kwargs - ) -> None: - """Short summary. - - Args: - in_channels (int): Number of in channels. - out_channels (int): umber of out channels. - expansion (int): Expansion factor of the out channels. Defaults to 1. - downsampling (int): Downsampling factor used in stride. Defaults to 1. - *args (type): Extra arguments. - **kwargs (type): Extra key value arguments. - - """ - super().__init__(in_channels, out_channels, *args, **kwargs) - self.expansion = expansion - self.downsampling = downsampling - - self.shortcut = ( - nn.Sequential( - nn.Conv2d( - in_channels=self.in_channels, - out_channels=self.expanded_channels, - kernel_size=1, - stride=self.downsampling, - bias=False, - ), - nn.BatchNorm2d(self.expanded_channels), - ) - if self.apply_shortcut - else None - ) - - @property - def expanded_channels(self) -> int: - """Computes the expanded output channels.""" - return self.out_channels * self.expansion - - @property - def apply_shortcut(self) -> bool: - """Check if shortcut should be applied.""" - return self.in_channels != self.expanded_channels - - -class BasicBlock(ResidualBlock): - """Basic ResNet block.""" - - expansion = 1 - - def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None: - super().__init__(in_channels, out_channels, *args, **kwargs) - self.blocks = nn.Sequential( - conv_bn( - in_channels=self.in_channels, - out_channels=self.out_channels, - bias=False, - stride=self.downsampling, - ), - self.activation_fn, - conv_bn( - in_channels=self.out_channels, - out_channels=self.expanded_channels, - bias=False, - ), - ) - - -class BottleNeckBlock(ResidualBlock): - """Bottleneck block to increase depth while minimizing parameter size.""" - - expansion = 4 - - def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None: - super().__init__(in_channels, out_channels, *args, **kwargs) - self.blocks = nn.Sequential( - conv_bn( - in_channels=self.in_channels, - out_channels=self.out_channels, - kernel_size=1, - ), - self.activation_fn, - conv_bn( - in_channels=self.out_channels, - out_channels=self.out_channels, - kernel_size=3, - stride=self.downsampling, - ), - self.activation_fn, - conv_bn( - in_channels=self.out_channels, - out_channels=self.expanded_channels, - kernel_size=1, - ), - ) - - -class ResidualLayer(nn.Module): - """ResNet layer.""" - - def __init__( - self, - in_channels: int, - out_channels: int, - block: BasicBlock = BasicBlock, - num_blocks: int = 1, - *args, - **kwargs - ) -> None: - super().__init__() - downsampling = 2 if in_channels != out_channels else 1 - self.blocks = nn.Sequential( - block( - in_channels, out_channels, *args, **kwargs, downsampling=downsampling - ), - *[ - block( - out_channels * block.expansion, - out_channels, - downsampling=1, - *args, - **kwargs - ) - for _ in range(num_blocks - 1) - ] - ) - - def forward(self, x: Tensor) -> Tensor: - """Forward pass.""" - x = self.blocks(x) - return x - - -class ResidualNetworkEncoder(nn.Module): - """Encoder network.""" - - def __init__( - self, - in_channels: int = 1, - block_sizes: Union[int, List[int]] = (32, 64), - depths: Union[int, List[int]] = (2, 2), - activation: str = "relu", - block: Type[nn.Module] = BasicBlock, - levels: int = 1, - *args, - **kwargs - ) -> None: - super().__init__() - self.block_sizes = ( - block_sizes if isinstance(block_sizes, list) else [block_sizes] * levels - ) - self.depths = depths if isinstance(depths, list) else [depths] * levels - self.activation = activation - self.gate = nn.Sequential( - nn.Conv2d( - in_channels=in_channels, - out_channels=self.block_sizes[0], - kernel_size=7, - stride=2, - padding=1, - bias=False, - ), - nn.BatchNorm2d(self.block_sizes[0]), - activation_function(self.activation), - # nn.MaxPool2d(kernel_size=2, stride=2, padding=1), - ) - - self.blocks = self._configure_blocks(block) - - def _configure_blocks( - self, block: Type[nn.Module], *args, **kwargs - ) -> nn.Sequential: - channels = [self.block_sizes[0]] + list( - zip(self.block_sizes, self.block_sizes[1:]) - ) - blocks = [ - ResidualLayer( - in_channels=channels[0], - out_channels=channels[0], - num_blocks=self.depths[0], - block=block, - activation=self.activation, - *args, - **kwargs - ) - ] - blocks += [ - ResidualLayer( - in_channels=in_channels * block.expansion, - out_channels=out_channels, - num_blocks=num_blocks, - block=block, - activation=self.activation, - *args, - **kwargs - ) - for (in_channels, out_channels), num_blocks in zip( - channels[1:], self.depths[1:] - ) - ] - - return nn.Sequential(*blocks) - - def forward(self, x: Tensor) -> Tensor: - """Forward pass.""" - # If batch dimenstion is missing, it needs to be added. - if len(x.shape) == 3: - x = x.unsqueeze(0) - x = self.gate(x) - x = self.blocks(x) - return x - - -class ResidualNetworkDecoder(nn.Module): - """Classification head.""" - - def __init__(self, in_features: int, num_classes: int = 80) -> None: - super().__init__() - self.decoder = nn.Sequential( - Reduce("b c h w -> b c", "mean"), - nn.Linear(in_features=in_features, out_features=num_classes), - ) - - def forward(self, x: Tensor) -> Tensor: - """Forward pass.""" - return self.decoder(x) - - -class ResidualNetwork(nn.Module): - """Full residual network.""" - - def __init__(self, in_channels: int, num_classes: int, *args, **kwargs) -> None: - super().__init__() - self.encoder = ResidualNetworkEncoder(in_channels, *args, **kwargs) - self.decoder = ResidualNetworkDecoder( - in_features=self.encoder.blocks[-1].blocks[-1].expanded_channels, - num_classes=num_classes, - ) - - def forward(self, x: Tensor) -> Tensor: - """Forward pass.""" - x = self.encoder(x) - x = self.decoder(x) - return x diff --git a/src/text_recognizer/networks/stn.py b/src/text_recognizer/networks/stn.py deleted file mode 100644 index e9d216f..0000000 --- a/src/text_recognizer/networks/stn.py +++ /dev/null @@ -1,44 +0,0 @@ -"""Spatial Transformer Network.""" - -from einops.layers.torch import Rearrange -import torch -from torch import nn -from torch import Tensor -import torch.nn.functional as F - - -class SpatialTransformerNetwork(nn.Module): - """A network with differentiable attention. - - Network that learns how to perform spatial transformations on the input image in order to enhance the - geometric invariance of the model. - - # TODO: add arguments to make it more general. - - """ - - def __init__(self) -> None: - super().__init__() - # Initialize the identity transformation and its weights and biases. - linear = nn.Linear(32, 3 * 2) - linear.weight.data.zero_() - linear.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)) - - self.theta = nn.Sequential( - nn.Conv2d(in_channels=1, out_channels=8, kernel_size=7), - nn.MaxPool2d(kernel_size=2, stride=2), - nn.ReLU(inplace=True), - nn.Conv2d(in_channels=8, out_channels=10, kernel_size=5), - nn.MaxPool2d(kernel_size=2, stride=2), - nn.ReLU(inplace=True), - Rearrange("b c h w -> b (c h w)", h=3, w=3), - nn.Linear(in_features=10 * 3 * 3, out_features=32), - nn.ReLU(inplace=True), - linear, - Rearrange("b (row col) -> b row col", row=2, col=3), - ) - - def forward(self, x: Tensor) -> Tensor: - """The spatial transformation.""" - grid = F.affine_grid(self.theta(x), x.shape) - return F.grid_sample(x, grid, align_corners=False) diff --git a/src/text_recognizer/networks/transducer/__init__.py b/src/text_recognizer/networks/transducer/__init__.py deleted file mode 100644 index 8c19a01..0000000 --- a/src/text_recognizer/networks/transducer/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Transducer modules.""" -from .tds_conv import TDS2d -from .transducer import load_transducer_loss, Transducer diff --git a/src/text_recognizer/networks/transducer/tds_conv.py b/src/text_recognizer/networks/transducer/tds_conv.py deleted file mode 100644 index 5fb8ba9..0000000 --- a/src/text_recognizer/networks/transducer/tds_conv.py +++ /dev/null @@ -1,208 +0,0 @@ -"""Time-Depth Separable Convolutions. - -References: - https://arxiv.org/abs/1904.02619 - https://arxiv.org/pdf/2010.01003.pdf - -Code stolen from: - https://github.com/facebookresearch/gtn_applications - - -""" -from typing import List, Tuple - -from einops import rearrange -import gtn -import numpy as np -import torch -from torch import nn -from torch import Tensor - - -class TDSBlock2d(nn.Module): - """Internal block of a 2D TDSC network.""" - - def __init__( - self, - in_channels: int, - img_depth: int, - kernel_size: Tuple[int], - dropout_rate: float, - ) -> None: - super().__init__() - - self.in_channels = in_channels - self.img_depth = img_depth - self.kernel_size = kernel_size - self.dropout_rate = dropout_rate - self.fc_dim = in_channels * img_depth - - # Network placeholders. - self.conv = None - self.mlp = None - self.instance_norm = None - - self._build_block() - - def _build_block(self) -> None: - # Convolutional block. - self.conv = nn.Sequential( - nn.Conv3d( - in_channels=self.in_channels, - out_channels=self.in_channels, - kernel_size=(1, self.kernel_size[0], self.kernel_size[1]), - padding=(0, self.kernel_size[0] // 2, self.kernel_size[1] // 2), - ), - nn.ReLU(inplace=True), - nn.Dropout(self.dropout_rate), - ) - - # MLP block. - self.mlp = nn.Sequential( - nn.Linear(self.fc_dim, self.fc_dim), - nn.ReLU(inplace=True), - nn.Dropout(self.dropout_rate), - nn.Linear(self.fc_dim, self.fc_dim), - nn.Dropout(self.dropout_rate), - ) - - # Instance norm. - self.instance_norm = nn.ModuleList( - [ - nn.InstanceNorm2d(self.fc_dim, affine=True), - nn.InstanceNorm2d(self.fc_dim, affine=True), - ] - ) - - def forward(self, x: Tensor) -> Tensor: - """Forward pass. - - Args: - x (Tensor): Input tensor. - - Shape: - - x: :math: `(B, CD, H, W)` - - Returns: - Tensor: Output tensor. - - """ - B, CD, H, W = x.shape - C, D = self.in_channels, self.img_depth - residual = x - x = rearrange(x, "b (c d) h w -> b c d h w", c=C, d=D) - x = self.conv(x) - x = rearrange(x, "b c d h w -> b (c d) h w") - x += residual - - x = self.instance_norm[0](x) - - x = self.mlp(x.transpose(1, 3)).transpose(1, 3) + x - x + self.instance_norm[1](x) - - # Output shape: [B, CD, H, W] - return x - - -class TDS2d(nn.Module): - """TDS Netowrk. - - Structure is the following: - Downsample layer -> TDS2d group -> ... -> Linear output layer - - - """ - - def __init__( - self, - input_dim: int, - output_dim: int, - depth: int, - tds_groups: Tuple[int], - kernel_size: Tuple[int], - dropout_rate: float, - in_channels: int = 1, - ) -> None: - super().__init__() - - self.in_channels = in_channels - self.input_dim = input_dim - self.output_dim = output_dim - self.depth = depth - self.tds_groups = tds_groups - self.kernel_size = kernel_size - self.dropout_rate = dropout_rate - - self.tds = None - self.fc = None - - self._build_network() - - def _build_network(self) -> None: - in_channels = self.in_channels - modules = [] - stride_h = np.prod([grp["stride"][0] for grp in self.tds_groups]) - if self.input_dim % stride_h: - raise RuntimeError( - f"Image height not divisible by total stride {stride_h}." - ) - - for tds_group in self.tds_groups: - # Add downsample layer. - out_channels = self.depth * tds_group["channels"] - modules.extend( - [ - nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=self.kernel_size, - padding=(self.kernel_size[0] // 2, self.kernel_size[1] // 2), - stride=tds_group["stride"], - ), - nn.ReLU(inplace=True), - nn.Dropout(self.dropout_rate), - nn.InstanceNorm2d(out_channels, affine=True), - ] - ) - - for _ in range(tds_group["num_blocks"]): - modules.append( - TDSBlock2d( - tds_group["channels"], - self.depth, - self.kernel_size, - self.dropout_rate, - ) - ) - - in_channels = out_channels - - self.tds = nn.Sequential(*modules) - self.fc = nn.Linear(in_channels * self.input_dim // stride_h, self.output_dim) - - def forward(self, x: Tensor) -> Tensor: - """Forward pass. - - Args: - x (Tensor): Input tensor. - - Shape: - - x: :math: `(B, H, W)` - - Returns: - Tensor: Output tensor. - - """ - if len(x.shape) == 4: - x = x.squeeze(1) # Squeeze the channel dim away. - - B, H, W = x.shape - x = rearrange( - x, "b (h1 h2) w -> b h1 h2 w", h1=self.in_channels, h2=H // self.in_channels - ) - x = self.tds(x) - - # x shape: [B, C, H, W] - x = rearrange(x, "b c h w -> b w (c h)") - - return self.fc(x) diff --git a/src/text_recognizer/networks/transducer/test.py b/src/text_recognizer/networks/transducer/test.py deleted file mode 100644 index cadcecc..0000000 --- a/src/text_recognizer/networks/transducer/test.py +++ /dev/null @@ -1,60 +0,0 @@ -import torch -from torch import nn - -from text_recognizer.networks.transducer import load_transducer_loss, Transducer -import unittest - - -class TestTransducer(unittest.TestCase): - def test_viterbi(self): - T = 5 - N = 4 - B = 2 - - # fmt: off - emissions1 = torch.tensor(( - 0, 4, 0, 1, - 0, 2, 1, 1, - 0, 0, 0, 2, - 0, 0, 0, 2, - 8, 0, 0, 2, - ), - dtype=torch.float, - ).view(T, N) - emissions2 = torch.tensor(( - 0, 2, 1, 7, - 0, 2, 9, 1, - 0, 0, 0, 2, - 0, 0, 5, 2, - 1, 0, 0, 2, - ), - dtype=torch.float, - ).view(T, N) - # fmt: on - - # Test without blank: - labels = [[1, 3, 0], [3, 2, 3, 2, 3]] - transducer = Transducer( - tokens=["a", "b", "c", "d"], - graphemes_to_idx={"a": 0, "b": 1, "c": 2, "d": 3}, - blank="none", - ) - emissions = torch.stack([emissions1, emissions2], dim=0) - predictions = transducer.viterbi(emissions) - self.assertEqual([p.tolist() for p in predictions], labels) - - # Test with blank without repeats: - labels = [[1, 0], [2, 2]] - transducer = Transducer( - tokens=["a", "b", "c"], - graphemes_to_idx={"a": 0, "b": 1, "c": 2}, - blank="optional", - allow_repeats=False, - ) - emissions = torch.stack([emissions1, emissions2], dim=0) - predictions = transducer.viterbi(emissions) - self.assertEqual([p.tolist() for p in predictions], labels) - - -if __name__ == "__main__": - unittest.main() diff --git a/src/text_recognizer/networks/transducer/transducer.py b/src/text_recognizer/networks/transducer/transducer.py deleted file mode 100644 index d7e3d08..0000000 --- a/src/text_recognizer/networks/transducer/transducer.py +++ /dev/null @@ -1,410 +0,0 @@ -"""Transducer and the transducer loss function.py - -Stolen from: - https://github.com/facebookresearch/gtn_applications/blob/master/transducer.py - -""" -from pathlib import Path -import itertools -from typing import Dict, List, Optional, Union, Tuple - -from loguru import logger -import gtn -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.datasets.iam_preprocessor import Preprocessor - - -def make_scalar_graph(weight) -> gtn.Graph: - scalar = gtn.Graph() - scalar.add_node(True) - scalar.add_node(False, True) - scalar.add_arc(0, 1, 0, 0, weight) - return scalar - - -def make_chain_graph(sequence) -> gtn.Graph: - graph = gtn.Graph(False) - graph.add_node(True) - for i, s in enumerate(sequence): - graph.add_node(False, i == (len(sequence) - 1)) - graph.add_arc(i, i + 1, s) - return graph - - -def make_transitions_graph( - ngram: int, num_tokens: int, calc_grad: bool = False -) -> gtn.Graph: - transitions = gtn.Graph(calc_grad) - transitions.add_node(True, ngram == 1) - - state_map = {(): 0} - - # First build transitions which include : - for n in range(1, ngram): - for state in itertools.product(range(num_tokens), repeat=n): - in_idx = state_map[state[:-1]] - out_idx = transitions.add_node(False, ngram == 1) - state_map[state] = out_idx - transitions.add_arc(in_idx, out_idx, state[-1]) - - for state in itertools.product(range(num_tokens), repeat=ngram): - state_idx = state_map[state[:-1]] - new_state_idx = state_map[state[1:]] - # p(state[-1] | state[:-1]) - transitions.add_arc(state_idx, new_state_idx, state[-1]) - - if ngram > 1: - # Build transitions which include : - end_idx = transitions.add_node(False, True) - for in_idx in range(end_idx): - transitions.add_arc(in_idx, end_idx, gtn.epsilon) - - return transitions - - -def make_lexicon_graph(word_pieces: List, graphemes_to_idx: Dict) -> gtn.Graph: - """Constructs a graph which transduces letters to word pieces.""" - graph = gtn.Graph(False) - graph.add_node(True, True) - for i, wp in enumerate(word_pieces): - prev = 0 - for l in wp[:-1]: - n = graph.add_node() - graph.add_arc(prev, n, graphemes_to_idx[l], gtn.epsilon) - prev = n - graph.add_arc(prev, 0, graphemes_to_idx[wp[-1]], i) - graph.arc_sort() - return graph - - -def make_token_graph( - token_list: List, blank: str = "none", allow_repeats: bool = True -) -> gtn.Graph: - """Constructs a graph with all the individual token transition models.""" - if not allow_repeats and blank != "optional": - raise ValueError("Must use blank='optional' if disallowing repeats.") - - ntoks = len(token_list) - graph = gtn.Graph(False) - - # Creating nodes - graph.add_node(True, True) - for i in range(ntoks): - # We can consume one or more consecutive word - # pieces for each emission: - # E.g. [ab, ab, ab] transduces to [ab] - graph.add_node(False, blank != "forced") - - if blank != "none": - graph.add_node() - - # Creating arcs - if blank != "none": - # Blank index is assumed to be last (ntoks) - graph.add_arc(0, ntoks + 1, ntoks, gtn.epsilon) - graph.add_arc(ntoks + 1, 0, gtn.epsilon) - - for i in range(ntoks): - graph.add_arc((ntoks + 1) if blank == "forced" else 0, i + 1, i) - graph.add_arc(i + 1, i + 1, i, gtn.epsilon) - - if allow_repeats: - if blank == "forced": - # Allow transitions from token to blank only - graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon) - else: - # Allow transition from token to blank and all other tokens - graph.add_arc(i + 1, 0, gtn.epsilon) - - else: - # allow transitions to blank and all other tokens except the same token - graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon) - for j in range(ntoks): - if i != j: - graph.add_arc(i + 1, j + 1, j, j) - - return graph - - -class TransducerLossFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - inputs, - targets, - tokens, - lexicon, - transition_params=None, - transitions=None, - reduction="none", - ) -> Tensor: - B, T, C = inputs.shape - - losses = [None] * B - emissions_graphs = [None] * B - - if transitions is not None: - if transition_params is None: - raise ValueError("Specified transitions, but not transition params.") - - cpu_data = transition_params.cpu().contiguous() - transitions.set_weights(cpu_data.data_ptr()) - transitions.calc_grad = transition_params.requires_grad - transitions.zero_grad() - - def process(b: int) -> None: - # Create emission graph: - emissions = gtn.linear_graph(T, C, inputs.requires_grad) - cpu_data = inputs[b].cpu().contiguous() - emissions.set_weights(cpu_data.data_ptr()) - target = make_chain_graph(targets[b]) - target.arc_sort(True) - - # Create token tot grapheme decomposition graph - tokens_target = gtn.remove(gtn.project_output(gtn.compose(target, lexicon))) - tokens_target.arc_sort() - - # Create alignment graph: - aligments = gtn.project_input( - gtn.remove(gtn.compose(tokens, tokens_target)) - ) - aligments.arc_sort() - - # Add transitions scores: - if transitions is not None: - aligments = gtn.intersect(transitions, aligments) - aligments.arc_sort() - - loss = gtn.forward_score(gtn.intersect(emissions, aligments)) - - # Normalize if needed: - if transitions is not None: - norm = gtn.forward_score(gtn.intersect(emissions, transitions)) - loss = gtn.subtract(loss, norm) - - losses[b] = gtn.negate(loss) - - # Save for backward: - if emissions.calc_grad: - emissions_graphs[b] = emissions - - gtn.parallel_for(process, range(B)) - - ctx.graphs = (losses, emissions_graphs, transitions) - ctx.input_shape = inputs.shape - - # Optionally reduce by target length - if reduction == "mean": - scales = [(1 / len(t) if len(t) > 0 else 1.0) for t in targets] - else: - scales = [1.0] * B - - ctx.scales = scales - - loss = torch.tensor([l.item() * s for l, s in zip(losses, scales)]) - return torch.mean(loss.to(inputs.device)) - - @staticmethod - def backward(ctx, grad_output) -> Tuple: - losses, emissions_graphs, transitions = ctx.graphs - scales = ctx.scales - - B, T, C = ctx.input_shape - calc_emissions = ctx.needs_input_grad[0] - input_grad = torch.empty((B, T, C)) if calc_emissions else None - - def process(b: int) -> None: - scale = make_scalar_graph(scales[b]) - gtn.backward(losses[b], scale) - emissions = emissions_graphs[b] - if calc_emissions: - grad = emissions.grad().weights_to_numpy() - input_grad[b] = torch.tensor(grad).view(1, T, C) - - gtn.parallel_for(process, range(B)) - - if calc_emissions: - input_grad = input_grad.to(grad_output.device) - input_grad *= grad_output / B - - if ctx.needs_input_grad[4]: - grad = transitions.grad().weights_to_numpy() - transition_grad = torch.tensor(grad).to(grad_output.device) - transition_grad *= grad_output / B - else: - transition_grad = None - - return ( - input_grad, - None, # target - None, # tokens - None, # lexicon - transition_grad, # transition params - None, # transitions graph - None, - ) - - -TransducerLoss = TransducerLossFunction.apply - - -class Transducer(nn.Module): - def __init__( - self, - tokens: List, - graphemes_to_idx: Dict, - ngram: int = 0, - transitions: str = None, - blank: str = "none", - allow_repeats: bool = True, - reduction: str = "none", - ) -> None: - """A generic transducer loss function. - - Args: - tokens (List) : A list of iterable objects (e.g. strings, tuples, etc) - representing the output tokens of the model (e.g. letters, - word-pieces, words). For example ["a", "b", "ab", "ba", "aba"] - could be a list of sub-word tokens. - graphemes_to_idx (dict) : A dictionary mapping grapheme units (e.g. - "a", "b", ..) to their corresponding integer index. - ngram (int) : Order of the token-level transition model. If `ngram=0` - then no transition model is used. - blank (string) : Specifies the usage of blank token - 'none' - do not use blank token - 'optional' - allow an optional blank inbetween tokens - 'forced' - force a blank inbetween tokens (also referred to as garbage token) - allow_repeats (boolean) : If false, then we don't allow paths with - consecutive tokens in the alignment graph. This keeps the graph - unambiguous in the sense that the same input cannot transduce to - different outputs. - """ - super().__init__() - if blank not in ["optional", "forced", "none"]: - raise ValueError( - "Invalid value specified for blank. Must be in ['optional', 'forced', 'none']" - ) - self.tokens = make_token_graph(tokens, blank=blank, allow_repeats=allow_repeats) - self.lexicon = make_lexicon_graph(tokens, graphemes_to_idx) - self.ngram = ngram - if ngram > 0 and transitions is not None: - raise ValueError("Only one of ngram and transitions may be specified") - - if ngram > 0: - transitions = make_transitions_graph( - ngram, len(tokens) + int(blank != "none"), True - ) - - if transitions is not None: - self.transitions = transitions - self.transitions.arc_sort() - self.transitions_params = nn.Parameter( - torch.zeros(self.transitions.num_arcs()) - ) - else: - self.transitions = None - self.transitions_params = None - self.reduction = reduction - - def forward(self, inputs: Tensor, targets: Tensor) -> TransducerLoss: - TransducerLoss( - inputs, - targets, - self.tokens, - self.lexicon, - self.transitions_params, - self.transitions, - self.reduction, - ) - - def viterbi(self, outputs: Tensor) -> List[Tensor]: - B, T, C = outputs.shape - - if self.transitions is not None: - cpu_data = self.transition_params.cpu().contiguous() - self.transitions.set_weights(cpu_data.data_ptr()) - self.transitions.calc_grad = False - - self.tokens.arc_sort() - - paths = [None] * B - - def process(b: int) -> None: - emissions = gtn.linear_graph(T, C, False) - cpu_data = outputs[b].cpu().contiguous() - emissions.set_weights(cpu_data.data_ptr()) - - if self.transitions is not None: - full_graph = gtn.intersect(emissions, self.transitions) - else: - full_graph = emissions - - # Find the best path and remove back-off arcs: - path = gtn.remove(gtn.viterbi_path(full_graph)) - - # Left compose the viterbi path with the "aligment to token" - # transducer to get the outputs: - path = gtn.compose(path, self.tokens) - - # When there are ambiguous paths (allow_repeats is true), we take - # the shortest: - path = gtn.viterbi_path(path) - path = gtn.remove(gtn.project_output(path)) - paths[b] = path.labels_to_list() - - gtn.parallel_for(process, range(B)) - predictions = [torch.IntTensor(path) for path in paths] - return predictions - - -def load_transducer_loss( - num_features: int, - ngram: int, - tokens: str, - lexicon: str, - transitions: str, - blank: str, - allow_repeats: bool, - prepend_wordsep: bool = False, - use_words: bool = False, - data_dir: Optional[Union[str, Path]] = None, - reduction: str = "mean", -) -> Tuple[Transducer, int]: - if data_dir is None: - data_dir = ( - Path(__file__).resolve().parents[4] / "data" / "raw" / "iam" / "iamdb" - ) - logger.debug(f"Using data dir: {data_dir}") - if not data_dir.exists(): - raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") - else: - data_dir = Path(data_dir) - processed_path = ( - Path(__file__).resolve().parents[4] / "data" / "processed" / "iam_lines" - ) - tokens_path = processed_path / tokens - lexicon_path = processed_path / lexicon - - if transitions is not None: - transitions = gtn.load(str(processed_path / transitions)) - - preprocessor = Preprocessor( - data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep, - ) - - num_tokens = preprocessor.num_tokens - - criterion = Transducer( - preprocessor.tokens, - preprocessor.graphemes_to_index, - ngram=ngram, - transitions=transitions, - blank=blank, - allow_repeats=allow_repeats, - reduction=reduction, - ) - - return criterion, num_tokens + int(blank != "none") diff --git a/src/text_recognizer/networks/transformer/__init__.py b/src/text_recognizer/networks/transformer/__init__.py deleted file mode 100644 index 9febc88..0000000 --- a/src/text_recognizer/networks/transformer/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Transformer modules.""" -from .positional_encoding import PositionalEncoding -from .transformer import Decoder, Encoder, EncoderLayer, Transformer diff --git a/src/text_recognizer/networks/transformer/attention.py b/src/text_recognizer/networks/transformer/attention.py deleted file mode 100644 index cce1ecc..0000000 --- a/src/text_recognizer/networks/transformer/attention.py +++ /dev/null @@ -1,93 +0,0 @@ -"""Implementes the attention module for the transformer.""" -from typing import Optional, Tuple - -from einops import rearrange -import numpy as np -import torch -from torch import nn -from torch import Tensor - - -class MultiHeadAttention(nn.Module): - """Implementation of multihead attention.""" - - def __init__( - self, hidden_dim: int, num_heads: int = 8, dropout_rate: float = 0.0 - ) -> None: - super().__init__() - self.hidden_dim = hidden_dim - self.num_heads = num_heads - self.fc_q = nn.Linear( - in_features=hidden_dim, out_features=hidden_dim, bias=False - ) - self.fc_k = nn.Linear( - in_features=hidden_dim, out_features=hidden_dim, bias=False - ) - self.fc_v = nn.Linear( - in_features=hidden_dim, out_features=hidden_dim, bias=False - ) - self.fc_out = nn.Linear(in_features=hidden_dim, out_features=hidden_dim) - - self._init_weights() - - self.dropout = nn.Dropout(p=dropout_rate) - - def _init_weights(self) -> None: - nn.init.normal_( - self.fc_q.weight, - mean=0, - std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)), - ) - nn.init.normal_( - self.fc_k.weight, - mean=0, - std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)), - ) - nn.init.normal_( - self.fc_v.weight, - mean=0, - std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)), - ) - nn.init.xavier_normal_(self.fc_out.weight) - - def scaled_dot_product_attention( - self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None - ) -> Tensor: - """Calculates the scaled dot product attention.""" - - # Compute the energy. - energy = torch.einsum("bhlk,bhtk->bhlt", [query, key]) / np.sqrt( - query.shape[-1] - ) - - # If we have a mask for padding some inputs. - if mask is not None: - energy = energy.masked_fill(mask == 0, -np.inf) - - # Compute the attention from the energy. - attention = torch.softmax(energy, dim=3) - - out = torch.einsum("bhlt,bhtv->bhlv", [attention, value]) - out = rearrange(out, "b head l v -> b l (head v)") - return out, attention - - def forward( - self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None - ) -> Tuple[Tensor, Tensor]: - """Forward pass for computing the multihead attention.""" - # Get the query, key, and value tensor. - query = rearrange( - self.fc_q(query), "b l (head k) -> b head l k", head=self.num_heads - ) - key = rearrange( - self.fc_k(key), "b t (head k) -> b head t k", head=self.num_heads - ) - value = rearrange( - self.fc_v(value), "b t (head v) -> b head t v", head=self.num_heads - ) - - out, attention = self.scaled_dot_product_attention(query, key, value, mask) - - out = self.fc_out(out) - out = self.dropout(out) - return out, attention diff --git a/src/text_recognizer/networks/transformer/positional_encoding.py b/src/text_recognizer/networks/transformer/positional_encoding.py deleted file mode 100644 index 1ba5537..0000000 --- a/src/text_recognizer/networks/transformer/positional_encoding.py +++ /dev/null @@ -1,32 +0,0 @@ -"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence.""" -import numpy as np -import torch -from torch import nn -from torch import Tensor - - -class PositionalEncoding(nn.Module): - """Encodes a sense of distance or time for transformer networks.""" - - def __init__( - self, hidden_dim: int, dropout_rate: float, max_len: int = 1000 - ) -> None: - super().__init__() - self.dropout = nn.Dropout(p=dropout_rate) - self.max_len = max_len - - pe = torch.zeros(max_len, hidden_dim) - position = torch.arange(0, max_len).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim) - ) - - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0) - self.register_buffer("pe", pe) - - def forward(self, x: Tensor) -> Tensor: - """Encodes the tensor with a postional embedding.""" - x = x + self.pe[:, : x.shape[1]] - return self.dropout(x) diff --git a/src/text_recognizer/networks/transformer/transformer.py b/src/text_recognizer/networks/transformer/transformer.py deleted file mode 100644 index dd180c4..0000000 --- a/src/text_recognizer/networks/transformer/transformer.py +++ /dev/null @@ -1,264 +0,0 @@ -"""Transfomer module.""" -import copy -from typing import Dict, Optional, Type, Union - -import numpy as np -import torch -from torch import nn -from torch import Tensor -import torch.nn.functional as F - -from text_recognizer.networks.transformer.attention import MultiHeadAttention -from text_recognizer.networks.util import activation_function - - -class GEGLU(nn.Module): - """GLU activation for improving feedforward activations.""" - - def __init__(self, dim_in: int, dim_out: int) -> None: - super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) - - def forward(self, x: Tensor) -> Tensor: - """Forward propagation.""" - x, gate = self.proj(x).chunk(2, dim=-1) - return x * F.gelu(gate) - - -def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList: - return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)]) - - -class _IntraLayerConnection(nn.Module): - """Preforms the residual connection inside the transfomer blocks and applies layernorm.""" - - def __init__(self, dropout_rate: float, hidden_dim: int) -> None: - super().__init__() - self.norm = nn.LayerNorm(normalized_shape=hidden_dim) - self.dropout = nn.Dropout(p=dropout_rate) - - def forward(self, src: Tensor, residual: Tensor) -> Tensor: - return self.norm(self.dropout(src) + residual) - - -class _ConvolutionalLayer(nn.Module): - def __init__( - self, - hidden_dim: int, - expansion_dim: int, - dropout_rate: float, - activation: str = "relu", - ) -> None: - super().__init__() - - in_projection = ( - nn.Sequential( - nn.Linear(hidden_dim, expansion_dim), activation_function(activation) - ) - if activation != "glu" - else GEGLU(hidden_dim, expansion_dim) - ) - - self.layer = nn.Sequential( - in_projection, - nn.Dropout(p=dropout_rate), - nn.Linear(in_features=expansion_dim, out_features=hidden_dim), - ) - - def forward(self, x: Tensor) -> Tensor: - return self.layer(x) - - -class EncoderLayer(nn.Module): - """Transfomer encoding layer.""" - - def __init__( - self, - hidden_dim: int, - num_heads: int, - expansion_dim: int, - dropout_rate: float, - activation: str = "relu", - ) -> None: - super().__init__() - self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate) - self.cnn = _ConvolutionalLayer( - hidden_dim, expansion_dim, dropout_rate, activation - ) - self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim) - self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim) - - def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor: - """Forward pass through the encoder.""" - # First block. - # Multi head attention. - out, _ = self.self_attention(src, src, src, mask) - - # Add & norm. - out = self.block1(out, src) - - # Second block. - # Apply 1D-convolution. - cnn_out = self.cnn(out) - - # Add & norm. - out = self.block2(cnn_out, out) - - return out - - -class Encoder(nn.Module): - """Transfomer encoder module.""" - - def __init__( - self, - num_layers: int, - encoder_layer: Type[nn.Module], - norm: Optional[Type[nn.Module]] = None, - ) -> None: - super().__init__() - self.layers = _get_clones(encoder_layer, num_layers) - self.norm = norm - - def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor: - """Forward pass through all encoder layers.""" - for layer in self.layers: - src = layer(src, src_mask) - - if self.norm is not None: - src = self.norm(src) - - return src - - -class DecoderLayer(nn.Module): - """Transfomer decoder layer.""" - - def __init__( - self, - hidden_dim: int, - num_heads: int, - expansion_dim: int, - dropout_rate: float = 0.0, - activation: str = "relu", - ) -> None: - super().__init__() - self.hidden_dim = hidden_dim - self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate) - self.multihead_attention = MultiHeadAttention( - hidden_dim, num_heads, dropout_rate - ) - self.cnn = _ConvolutionalLayer( - hidden_dim, expansion_dim, dropout_rate, activation - ) - self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim) - self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim) - self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim) - - def forward( - self, - trg: Tensor, - memory: Tensor, - trg_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - ) -> Tensor: - """Forward pass of the layer.""" - out, _ = self.self_attention(trg, trg, trg, trg_mask) - trg = self.block1(out, trg) - - out, _ = self.multihead_attention(trg, memory, memory, memory_mask) - trg = self.block2(out, trg) - - out = self.cnn(trg) - out = self.block3(out, trg) - - return out - - -class Decoder(nn.Module): - """Transfomer decoder module.""" - - def __init__( - self, - decoder_layer: Type[nn.Module], - num_layers: int, - norm: Optional[Type[nn.Module]] = None, - ) -> None: - super().__init__() - self.layers = _get_clones(decoder_layer, num_layers) - self.num_layers = num_layers - self.norm = norm - - def forward( - self, - trg: Tensor, - memory: Tensor, - trg_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - ) -> Tensor: - """Forward pass through the decoder.""" - for layer in self.layers: - trg = layer(trg, memory, trg_mask, memory_mask) - - if self.norm is not None: - trg = self.norm(trg) - - return trg - - -class Transformer(nn.Module): - """Transformer network.""" - - def __init__( - self, - num_encoder_layers: int, - num_decoder_layers: int, - hidden_dim: int, - num_heads: int, - expansion_dim: int, - dropout_rate: float, - activation: str = "relu", - ) -> None: - super().__init__() - - # Configure encoder. - encoder_norm = nn.LayerNorm(hidden_dim) - encoder_layer = EncoderLayer( - hidden_dim, num_heads, expansion_dim, dropout_rate, activation - ) - self.encoder = Encoder(num_encoder_layers, encoder_layer, encoder_norm) - - # Configure decoder. - decoder_norm = nn.LayerNorm(hidden_dim) - decoder_layer = DecoderLayer( - hidden_dim, num_heads, expansion_dim, dropout_rate, activation - ) - self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm) - - self._reset_parameters() - - def _reset_parameters(self) -> None: - for p in self.parameters(): - if p.dim() > 1: - nn.init.xavier_uniform_(p) - - def forward( - self, - src: Tensor, - trg: Tensor, - src_mask: Optional[Tensor] = None, - trg_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - ) -> Tensor: - """Forward pass through the transformer.""" - if src.shape[0] != trg.shape[0]: - print(trg.shape) - raise RuntimeError("The batch size of the src and trg must be the same.") - if src.shape[2] != trg.shape[2]: - raise RuntimeError( - "The number of features for the src and trg must be the same." - ) - - memory = self.encoder(src, src_mask) - output = self.decoder(trg, memory, trg_mask, memory_mask) - return output diff --git a/src/text_recognizer/networks/unet.py b/src/text_recognizer/networks/unet.py deleted file mode 100644 index 510910f..0000000 --- a/src/text_recognizer/networks/unet.py +++ /dev/null @@ -1,255 +0,0 @@ -"""UNet for segmentation.""" -from typing import List, Optional, Tuple, Union - -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.util import activation_function - - -class _ConvBlock(nn.Module): - """Modified UNet convolutional block with dilation.""" - - def __init__( - self, - channels: List[int], - activation: str, - num_groups: int, - dropout_rate: float = 0.1, - kernel_size: int = 3, - dilation: int = 1, - padding: int = 0, - ) -> None: - super().__init__() - self.channels = channels - self.dropout_rate = dropout_rate - self.kernel_size = kernel_size - self.dilation = dilation - self.padding = padding - self.num_groups = num_groups - self.activation = activation_function(activation) - self.block = self._configure_block() - self.residual_conv = nn.Sequential( - nn.Conv2d( - self.channels[0], self.channels[-1], kernel_size=3, stride=1, padding=1 - ), - self.activation, - ) - - def _configure_block(self) -> nn.Sequential: - block = [] - for i in range(len(self.channels) - 1): - block += [ - nn.Dropout(p=self.dropout_rate), - nn.GroupNorm(self.num_groups, self.channels[i]), - self.activation, - nn.Conv2d( - self.channels[i], - self.channels[i + 1], - kernel_size=self.kernel_size, - padding=self.padding, - stride=1, - dilation=self.dilation, - ), - ] - - return nn.Sequential(*block) - - def forward(self, x: Tensor) -> Tensor: - """Apply the convolutional block.""" - residual = self.residual_conv(x) - return self.block(x) + residual - - -class _DownSamplingBlock(nn.Module): - """Basic down sampling block.""" - - def __init__( - self, - channels: List[int], - activation: str, - num_groups: int, - pooling_kernel: Union[int, bool] = 2, - dropout_rate: float = 0.1, - kernel_size: int = 3, - dilation: int = 1, - padding: int = 0, - ) -> None: - super().__init__() - self.conv_block = _ConvBlock( - channels, - activation, - num_groups, - dropout_rate, - kernel_size, - dilation, - padding, - ) - self.down_sampling = nn.MaxPool2d(pooling_kernel) if pooling_kernel else None - - def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: - """Return the convolutional block output and a down sampled tensor.""" - x = self.conv_block(x) - x_down = self.down_sampling(x) if self.down_sampling is not None else x - - return x_down, x - - -class _UpSamplingBlock(nn.Module): - """The upsampling block of the UNet.""" - - def __init__( - self, - channels: List[int], - activation: str, - num_groups: int, - scale_factor: int = 2, - dropout_rate: float = 0.1, - kernel_size: int = 3, - dilation: int = 1, - padding: int = 0, - ) -> None: - super().__init__() - self.conv_block = _ConvBlock( - channels, - activation, - num_groups, - dropout_rate, - kernel_size, - dilation, - padding, - ) - self.up_sampling = nn.Upsample( - scale_factor=scale_factor, mode="bilinear", align_corners=True - ) - - def forward(self, x: Tensor, x_skip: Optional[Tensor] = None) -> Tensor: - """Apply the up sampling and convolutional block.""" - x = self.up_sampling(x) - if x_skip is not None: - x = torch.cat((x, x_skip), dim=1) - return self.conv_block(x) - - -class UNet(nn.Module): - """UNet architecture.""" - - def __init__( - self, - in_channels: int = 1, - base_channels: int = 64, - num_classes: int = 3, - depth: int = 4, - activation: str = "relu", - num_groups: int = 8, - dropout_rate: float = 0.1, - pooling_kernel: int = 2, - scale_factor: int = 2, - kernel_size: Optional[List[int]] = None, - dilation: Optional[List[int]] = None, - padding: Optional[List[int]] = None, - ) -> None: - super().__init__() - self.depth = depth - self.num_groups = num_groups - - if kernel_size is not None and dilation is not None and padding is not None: - if ( - len(kernel_size) != depth - and len(dilation) != depth - and len(padding) != depth - ): - raise RuntimeError( - "Length of convolutional parameters does not match the depth." - ) - self.kernel_size = kernel_size - self.padding = padding - self.dilation = dilation - - else: - self.kernel_size = [3] * depth - self.padding = [1] * depth - self.dilation = [1] * depth - - self.dropout_rate = dropout_rate - self.conv = nn.Conv2d( - in_channels, base_channels, kernel_size=3, stride=1, padding=1 - ) - - channels = [base_channels] + [base_channels * 2 ** i for i in range(depth)] - self.encoder_blocks = self._configure_down_sampling_blocks( - channels, activation, pooling_kernel - ) - self.decoder_blocks = self._configure_up_sampling_blocks( - channels, activation, scale_factor - ) - - self.head = nn.Conv2d(base_channels, num_classes, kernel_size=1) - - def _configure_down_sampling_blocks( - self, channels: List[int], activation: str, pooling_kernel: int - ) -> nn.ModuleList: - blocks = nn.ModuleList([]) - for i in range(len(channels) - 1): - pooling_kernel = pooling_kernel if i < self.depth - 1 else False - dropout_rate = self.dropout_rate if i < 0 else 0 - blocks += [ - _DownSamplingBlock( - [channels[i], channels[i + 1], channels[i + 1]], - activation, - self.num_groups, - pooling_kernel, - dropout_rate, - self.kernel_size[i], - self.dilation[i], - self.padding[i], - ) - ] - - return blocks - - def _configure_up_sampling_blocks( - self, channels: List[int], activation: str, scale_factor: int, - ) -> nn.ModuleList: - channels.reverse() - self.kernel_size.reverse() - self.dilation.reverse() - self.padding.reverse() - return nn.ModuleList( - [ - _UpSamplingBlock( - [channels[i] + channels[i + 1], channels[i + 1], channels[i + 1]], - activation, - self.num_groups, - scale_factor, - self.dropout_rate, - self.kernel_size[i], - self.dilation[i], - self.padding[i], - ) - for i in range(len(channels) - 2) - ] - ) - - def _encode(self, x: Tensor) -> List[Tensor]: - x_skips = [] - for block in self.encoder_blocks: - x, x_skip = block(x) - x_skips.append(x_skip) - return x_skips - - def _decode(self, x_skips: List[Tensor]) -> Tensor: - x = x_skips[-1] - for i, block in enumerate(self.decoder_blocks): - x = block(x, x_skips[-(i + 2)]) - return x - - def forward(self, x: Tensor) -> Tensor: - """Forward pass with the UNet model.""" - if len(x.shape) < 4: - x = x[(None,) * (4 - len(x.shape))] - x = self.conv(x) - x_skips = self._encode(x) - x = self._decode(x_skips) - return self.head(x) diff --git a/src/text_recognizer/networks/util.py b/src/text_recognizer/networks/util.py deleted file mode 100644 index 131a6b4..0000000 --- a/src/text_recognizer/networks/util.py +++ /dev/null @@ -1,89 +0,0 @@ -"""Miscellaneous neural network functionality.""" -import importlib -from pathlib import Path -from typing import Dict, Tuple, Type - -from einops import rearrange -from loguru import logger -import torch -from torch import nn - - -def sliding_window( - images: torch.Tensor, patch_size: Tuple[int, int], stride: Tuple[int, int] -) -> torch.Tensor: - """Creates patches of an image. - - Args: - images (torch.Tensor): A Torch tensor of a 4D image(s), i.e. (batch, channel, height, width). - patch_size (Tuple[int, int]): The size of the patches to generate, e.g. 28x28 for EMNIST. - stride (Tuple[int, int]): The stride of the sliding window. - - Returns: - torch.Tensor: A tensor with the shape (batch, patches, height, width). - - """ - unfold = nn.Unfold(kernel_size=patch_size, stride=stride) - # Preform the sliding window, unsqueeze as the channel dimesion is lost. - c = images.shape[1] - patches = unfold(images) - patches = rearrange( - patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1], - ) - return patches - - -def activation_function(activation: str) -> Type[nn.Module]: - """Returns the callable activation function.""" - activation_fns = nn.ModuleDict( - [ - ["elu", nn.ELU(inplace=True)], - ["gelu", nn.GELU()], - ["glu", nn.GLU()], - ["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)], - ["none", nn.Identity()], - ["relu", nn.ReLU(inplace=True)], - ["selu", nn.SELU(inplace=True)], - ] - ) - return activation_fns[activation.lower()] - - -def configure_backbone(backbone: str, backbone_args: Dict) -> Type[nn.Module]: - """Loads a backbone network.""" - network_module = importlib.import_module("text_recognizer.networks") - backbone_ = getattr(network_module, backbone) - - if "pretrained" in backbone_args: - logger.info("Loading pretrained backbone.") - checkpoint_file = Path(__file__).resolve().parents[2] / backbone_args.pop( - "pretrained" - ) - - # Loading state directory. - state_dict = torch.load(checkpoint_file) - network_args = state_dict["network_args"] - weights = state_dict["model_state"] - - freeze = False - if "freeze" in backbone_args and backbone_args["freeze"] is True: - backbone_args.pop("freeze") - freeze = True - network_args = backbone_args - - # Initializes the network with trained weights. - backbone = backbone_(**network_args) - backbone.load_state_dict(weights) - if freeze: - for params in backbone.parameters(): - params.requires_grad = False - else: - backbone_ = getattr(network_module, backbone) - backbone = backbone_(**backbone_args) - - if "remove_layers" in backbone_args and backbone_args["remove_layers"] is not None: - backbone = nn.Sequential( - *list(backbone.children())[:][: -backbone_args["remove_layers"]] - ) - - return backbone diff --git a/src/text_recognizer/networks/vit.py b/src/text_recognizer/networks/vit.py deleted file mode 100644 index efb3701..0000000 --- a/src/text_recognizer/networks/vit.py +++ /dev/null @@ -1,150 +0,0 @@ -"""A Vision Transformer. - -Inspired by: -https://openreview.net/pdf?id=YicbFdNTTy - -""" -from typing import Optional, Tuple - -from einops import rearrange, repeat -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.transformer import Transformer - - -class ViT(nn.Module): - """Transfomer for image to sequence prediction.""" - - def __init__( - self, - num_encoder_layers: int, - num_decoder_layers: int, - hidden_dim: int, - vocab_size: int, - num_heads: int, - expansion_dim: int, - patch_dim: Tuple[int, int], - image_size: Tuple[int, int], - dropout_rate: float, - trg_pad_index: int, - max_len: int, - activation: str = "gelu", - ) -> None: - super().__init__() - - self.trg_pad_index = trg_pad_index - self.patch_dim = patch_dim - self.num_patches = image_size[-1] // self.patch_dim[1] - - # Encoder - self.patch_to_embedding = nn.Linear( - self.patch_dim[0] * self.patch_dim[1], hidden_dim - ) - self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim)) - self.character_embedding = nn.Embedding(vocab_size, hidden_dim) - self.pos_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim)) - self.dropout = nn.Dropout(dropout_rate) - self._init() - - self.transformer = Transformer( - num_encoder_layers, - num_decoder_layers, - hidden_dim, - num_heads, - expansion_dim, - dropout_rate, - activation, - ) - - self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),) - - def _init(self) -> None: - nn.init.normal_(self.character_embedding.weight, std=0.02) - # nn.init.normal_(self.pos_embedding.weight, std=0.02) - - def _create_trg_mask(self, trg: Tensor) -> Tensor: - # Move this outside the transformer. - trg_pad_mask = (trg != self.trg_pad_index)[:, None, None] - trg_len = trg.shape[1] - trg_sub_mask = torch.tril( - torch.ones((trg_len, trg_len), device=trg.device) - ).bool() - trg_mask = trg_pad_mask & trg_sub_mask - return trg_mask - - def encoder(self, src: Tensor) -> Tensor: - """Forward pass with the encoder of the transformer.""" - return self.transformer.encoder(src) - - def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor: - """Forward pass with the decoder of the transformer + classification head.""" - return self.head( - self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask) - ) - - def extract_image_features(self, src: Tensor) -> Tensor: - """Extracts image features with a backbone neural network. - - It seem like the winning idea was to swap channels and width dimension and collapse - the height dimension. The transformer is learning like a baby with this implementation!!! :D - Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D - - Args: - src (Tensor): Input tensor. - - Returns: - Tensor: A input src to the transformer. - - """ - # If batch dimension is missing, it needs to be added. - if len(src.shape) < 4: - src = src[(None,) * (4 - len(src.shape))] - - patches = rearrange( - src, - "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", - p1=self.patch_dim[0], - p2=self.patch_dim[1], - ) - - # From patches to encoded sequence. - x = self.patch_to_embedding(patches) - b, n, _ = x.shape - cls_tokens = repeat(self.cls_token, "() n d -> b n d", b=b) - x = torch.cat((cls_tokens, x), dim=1) - x += self.pos_embedding[:, : (n + 1)] - x = self.dropout(x) - - return x - - def target_embedding(self, trg: Tensor) -> Tuple[Tensor, Tensor]: - """Encodes target tensor with embedding and postion. - - Args: - trg (Tensor): Target tensor. - - Returns: - Tuple[Tensor, Tensor]: Encoded target tensor and target mask. - - """ - _, n = trg.shape - trg = self.character_embedding(trg.long()) - trg += self.pos_embedding[:, :n] - return trg - - def decode_image_features(self, h: Tensor, trg: Optional[Tensor] = None) -> Tensor: - """Takes images features from the backbone and decodes them with the transformer.""" - trg_mask = self._create_trg_mask(trg) - trg = self.target_embedding(trg) - out = self.transformer(h, trg, trg_mask=trg_mask) - - logits = self.head(out) - return logits - - def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor: - """Forward pass with CNN transfomer.""" - h = self.extract_image_features(x) - logits = self.decode_image_features(h, trg) - return logits diff --git a/src/text_recognizer/networks/vq_transformer.py b/src/text_recognizer/networks/vq_transformer.py deleted file mode 100644 index c673d96..0000000 --- a/src/text_recognizer/networks/vq_transformer.py +++ /dev/null @@ -1,150 +0,0 @@ -"""A VQ-Transformer for image to text recognition.""" -from typing import Dict, Optional, Tuple - -from einops import rearrange, repeat -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.transformer import PositionalEncoding, Transformer -from text_recognizer.networks.util import activation_function -from text_recognizer.networks.util import configure_backbone -from text_recognizer.networks.vqvae.encoder import _ResidualBlock - - -class VQTransformer(nn.Module): - """VQ+Transfomer for image to character sequence prediction.""" - - def __init__( - self, - num_encoder_layers: int, - num_decoder_layers: int, - hidden_dim: int, - vocab_size: int, - num_heads: int, - adaptive_pool_dim: Tuple, - expansion_dim: int, - dropout_rate: float, - trg_pad_index: int, - max_len: int, - backbone: str, - backbone_args: Optional[Dict] = None, - activation: str = "gelu", - ) -> None: - super().__init__() - - # Configure vector quantized backbone. - self.backbone = configure_backbone(backbone, backbone_args) - self.conv = nn.Sequential( - nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2), - nn.ReLU(inplace=True), - ) - - # Configure embeddings for Transformer network. - self.trg_pad_index = trg_pad_index - self.vocab_size = vocab_size - self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim) - self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim)) - self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate) - nn.init.normal_(self.character_embedding.weight, std=0.02) - - self.adaptive_pool = ( - nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None - ) - - self.transformer = Transformer( - num_encoder_layers, - num_decoder_layers, - hidden_dim, - num_heads, - expansion_dim, - dropout_rate, - activation, - ) - - self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),) - - def _create_trg_mask(self, trg: Tensor) -> Tensor: - # Move this outside the transformer. - trg_pad_mask = (trg != self.trg_pad_index)[:, None, None] - trg_len = trg.shape[1] - trg_sub_mask = torch.tril( - torch.ones((trg_len, trg_len), device=trg.device) - ).bool() - trg_mask = trg_pad_mask & trg_sub_mask - return trg_mask - - def encoder(self, src: Tensor) -> Tensor: - """Forward pass with the encoder of the transformer.""" - return self.transformer.encoder(src) - - def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor: - """Forward pass with the decoder of the transformer + classification head.""" - return self.head( - self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask) - ) - - def extract_image_features(self, src: Tensor) -> Tuple[Tensor, Tensor]: - """Extracts image features with a backbone neural network. - - It seem like the winning idea was to swap channels and width dimension and collapse - the height dimension. The transformer is learning like a baby with this implementation!!! :D - Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D - - Args: - src (Tensor): Input tensor. - - Returns: - Tensor: The input src to the transformer and the vq loss. - - """ - # If batch dimension is missing, it needs to be added. - if len(src.shape) < 4: - src = src[(None,) * (4 - len(src.shape))] - src, vq_loss = self.backbone.encode(src) - # src = self.backbone.decoder.res_block(src) - src = self.conv(src) - - if self.adaptive_pool is not None: - src = rearrange(src, "b c h w -> b w c h") - src = self.adaptive_pool(src) - src = src.squeeze(3) - else: - src = rearrange(src, "b c h w -> b (w h) c") - - b, t, _ = src.shape - - src += self.src_position_embedding[:, :t] - - return src, vq_loss - - def target_embedding(self, trg: Tensor) -> Tensor: - """Encodes target tensor with embedding and postion. - - Args: - trg (Tensor): Target tensor. - - Returns: - Tensor: Encoded target tensor. - - """ - trg = self.character_embedding(trg.long()) - trg = self.trg_position_encoding(trg) - return trg - - def decode_image_features( - self, image_features: Tensor, trg: Optional[Tensor] = None - ) -> Tensor: - """Takes images features from the backbone and decodes them with the transformer.""" - trg_mask = self._create_trg_mask(trg) - trg = self.target_embedding(trg) - out = self.transformer(image_features, trg, trg_mask=trg_mask) - - logits = self.head(out) - return logits - - def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor: - """Forward pass with CNN transfomer.""" - image_features, vq_loss = self.extract_image_features(x) - logits = self.decode_image_features(image_features, trg) - return logits, vq_loss diff --git a/src/text_recognizer/networks/vqvae/__init__.py b/src/text_recognizer/networks/vqvae/__init__.py deleted file mode 100644 index 763953c..0000000 --- a/src/text_recognizer/networks/vqvae/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""VQ-VAE module.""" -from .decoder import Decoder -from .encoder import Encoder -from .vector_quantizer import VectorQuantizer -from .vqvae import VQVAE diff --git a/src/text_recognizer/networks/vqvae/decoder.py b/src/text_recognizer/networks/vqvae/decoder.py deleted file mode 100644 index 8847aba..0000000 --- a/src/text_recognizer/networks/vqvae/decoder.py +++ /dev/null @@ -1,133 +0,0 @@ -"""CNN decoder for the VQ-VAE.""" - -from typing import List, Optional, Tuple, Type - -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.util import activation_function -from text_recognizer.networks.vqvae.encoder import _ResidualBlock - - -class Decoder(nn.Module): - """A CNN encoder network.""" - - def __init__( - self, - channels: List[int], - kernel_sizes: List[int], - strides: List[int], - num_residual_layers: int, - embedding_dim: int, - upsampling: Optional[List[List[int]]] = None, - activation: str = "leaky_relu", - dropout_rate: float = 0.0, - ) -> None: - super().__init__() - - if dropout_rate: - if activation == "selu": - dropout = nn.AlphaDropout(p=dropout_rate) - else: - dropout = nn.Dropout(p=dropout_rate) - else: - dropout = None - - self.upsampling = upsampling - - self.res_block = nn.ModuleList([]) - self.upsampling_block = nn.ModuleList([]) - - self.embedding_dim = embedding_dim - activation = activation_function(activation) - - # Configure encoder. - self.decoder = self._build_decoder( - channels, kernel_sizes, strides, num_residual_layers, activation, dropout, - ) - - def _build_decompression_block( - self, - in_channels: int, - channels: int, - kernel_sizes: List[int], - strides: List[int], - activation: Type[nn.Module], - dropout: Optional[Type[nn.Module]], - ) -> nn.ModuleList: - modules = nn.ModuleList([]) - configuration = zip(channels, kernel_sizes, strides) - for i, (out_channels, kernel_size, stride) in enumerate(configuration): - modules.append( - nn.Sequential( - nn.ConvTranspose2d( - in_channels, - out_channels, - kernel_size, - stride=stride, - padding=1, - ), - activation, - ) - ) - - if i < len(self.upsampling): - modules.append(nn.Upsample(size=self.upsampling[i]),) - - if dropout is not None: - modules.append(dropout) - - in_channels = out_channels - - modules.extend( - nn.Sequential( - nn.ConvTranspose2d( - in_channels, 1, kernel_size=kernel_size, stride=stride, padding=1 - ), - nn.Tanh(), - ) - ) - - return modules - - def _build_decoder( - self, - channels: int, - kernel_sizes: List[int], - strides: List[int], - num_residual_layers: int, - activation: Type[nn.Module], - dropout: Optional[Type[nn.Module]], - ) -> nn.Sequential: - - self.res_block.append( - nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,) - ) - - # Bottleneck module. - self.res_block.extend( - nn.ModuleList( - [ - _ResidualBlock(channels[0], channels[0], dropout) - for i in range(num_residual_layers) - ] - ) - ) - - # Decompression module - self.upsampling_block.extend( - self._build_decompression_block( - channels[0], channels[1:], kernel_sizes, strides, activation, dropout - ) - ) - - self.res_block = nn.Sequential(*self.res_block) - self.upsampling_block = nn.Sequential(*self.upsampling_block) - - return nn.Sequential(self.res_block, self.upsampling_block) - - def forward(self, z_q: Tensor) -> Tensor: - """Reconstruct input from given codes.""" - x_reconstruction = self.decoder(z_q) - return x_reconstruction diff --git a/src/text_recognizer/networks/vqvae/encoder.py b/src/text_recognizer/networks/vqvae/encoder.py deleted file mode 100644 index d3adac5..0000000 --- a/src/text_recognizer/networks/vqvae/encoder.py +++ /dev/null @@ -1,147 +0,0 @@ -"""CNN encoder for the VQ-VAE.""" -from typing import List, Optional, Tuple, Type - -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.util import activation_function -from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer - - -class _ResidualBlock(nn.Module): - def __init__( - self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]], - ) -> None: - super().__init__() - self.block = [ - nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), - nn.ReLU(inplace=True), - nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False), - ] - - if dropout is not None: - self.block.append(dropout) - - self.block = nn.Sequential(*self.block) - - def forward(self, x: Tensor) -> Tensor: - """Apply the residual forward pass.""" - return x + self.block(x) - - -class Encoder(nn.Module): - """A CNN encoder network.""" - - def __init__( - self, - in_channels: int, - channels: List[int], - kernel_sizes: List[int], - strides: List[int], - num_residual_layers: int, - embedding_dim: int, - num_embeddings: int, - beta: float = 0.25, - activation: str = "leaky_relu", - dropout_rate: float = 0.0, - ) -> None: - super().__init__() - - if dropout_rate: - if activation == "selu": - dropout = nn.AlphaDropout(p=dropout_rate) - else: - dropout = nn.Dropout(p=dropout_rate) - else: - dropout = None - - self.embedding_dim = embedding_dim - self.num_embeddings = num_embeddings - self.beta = beta - activation = activation_function(activation) - - # Configure encoder. - self.encoder = self._build_encoder( - in_channels, - channels, - kernel_sizes, - strides, - num_residual_layers, - activation, - dropout, - ) - - # Configure Vector Quantizer. - self.vector_quantizer = VectorQuantizer( - self.num_embeddings, self.embedding_dim, self.beta - ) - - def _build_compression_block( - self, - in_channels: int, - channels: int, - kernel_sizes: List[int], - strides: List[int], - activation: Type[nn.Module], - dropout: Optional[Type[nn.Module]], - ) -> nn.ModuleList: - modules = nn.ModuleList([]) - configuration = zip(channels, kernel_sizes, strides) - for out_channels, kernel_size, stride in configuration: - modules.append( - nn.Sequential( - nn.Conv2d( - in_channels, out_channels, kernel_size, stride=stride, padding=1 - ), - activation, - ) - ) - - if dropout is not None: - modules.append(dropout) - - in_channels = out_channels - - return modules - - def _build_encoder( - self, - in_channels: int, - channels: int, - kernel_sizes: List[int], - strides: List[int], - num_residual_layers: int, - activation: Type[nn.Module], - dropout: Optional[Type[nn.Module]], - ) -> nn.Sequential: - encoder = nn.ModuleList([]) - - # compression module - encoder.extend( - self._build_compression_block( - in_channels, channels, kernel_sizes, strides, activation, dropout - ) - ) - - # Bottleneck module. - encoder.extend( - nn.ModuleList( - [ - _ResidualBlock(channels[-1], channels[-1], dropout) - for i in range(num_residual_layers) - ] - ) - ) - - encoder.append( - nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,) - ) - - return nn.Sequential(*encoder) - - def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: - """Encodes input into a discrete representation.""" - z_e = self.encoder(x) - z_q, vq_loss = self.vector_quantizer(z_e) - return z_q, vq_loss diff --git a/src/text_recognizer/networks/vqvae/vector_quantizer.py b/src/text_recognizer/networks/vqvae/vector_quantizer.py deleted file mode 100644 index f92c7ee..0000000 --- a/src/text_recognizer/networks/vqvae/vector_quantizer.py +++ /dev/null @@ -1,119 +0,0 @@ -"""Implementation of a Vector Quantized Variational AutoEncoder. - -Reference: -https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py - -""" - -from einops import rearrange -import torch -from torch import nn -from torch import Tensor -from torch.nn import functional as F - - -class VectorQuantizer(nn.Module): - """The codebook that contains quantized vectors.""" - - def __init__( - self, num_embeddings: int, embedding_dim: int, beta: float = 0.25 - ) -> None: - super().__init__() - self.K = num_embeddings - self.D = embedding_dim - self.beta = beta - - self.embedding = nn.Embedding(self.K, self.D) - - # Initialize the codebook. - nn.init.uniform_(self.embedding.weight, -1 / self.K, 1 / self.K) - - def discretization_bottleneck(self, latent: Tensor) -> Tensor: - """Computes the code nearest to the latent representation. - - First we compute the posterior categorical distribution, and then map - the latent representation to the nearest element of the embedding. - - Args: - latent (Tensor): The latent representation. - - Shape: - - latent :math:`(B x H x W, D)` - - Returns: - Tensor: The quantized embedding vector. - - """ - # Store latent shape. - b, h, w, d = latent.shape - - # Flatten the latent representation to 2D. - latent = rearrange(latent, "b h w d -> (b h w) d") - - # Compute the L2 distance between the latents and the embeddings. - l2_distance = ( - torch.sum(latent ** 2, dim=1, keepdim=True) - + torch.sum(self.embedding.weight ** 2, dim=1) - - 2 * latent @ self.embedding.weight.t() - ) # [BHW x K] - - # Find the embedding k nearest to each latent. - encoding_indices = torch.argmin(l2_distance, dim=1).unsqueeze(1) # [BHW, 1] - - # Convert to one-hot encodings, aka discrete bottleneck. - one_hot_encoding = torch.zeros( - encoding_indices.shape[0], self.K, device=latent.device - ) - one_hot_encoding.scatter_(1, encoding_indices, 1) # [BHW x K] - - # Embedding quantization. - quantized_latent = one_hot_encoding @ self.embedding.weight # [BHW, D] - quantized_latent = rearrange( - quantized_latent, "(b h w) d -> b h w d", b=b, h=h, w=w - ) - - return quantized_latent - - def vq_loss(self, latent: Tensor, quantized_latent: Tensor) -> Tensor: - """Vector Quantization loss. - - The vector quantization algorithm allows us to create a codebook. The VQ - algorithm works by moving the embedding vectors towards the encoder outputs. - - The embedding loss moves the embedding vector towards the encoder outputs. The - .detach() works as the stop gradient (sg) described in the paper. - - Because the volume of the embedding space is dimensionless, it can arbitarily - grow if the embeddings are not trained as fast as the encoder parameters. To - mitigate this, a commitment loss is added in the second term which makes sure - that the encoder commits to an embedding and that its output does not grow. - - Args: - latent (Tensor): The encoder output. - quantized_latent (Tensor): The quantized latent. - - Returns: - Tensor: The combinded VQ loss. - - """ - embedding_loss = F.mse_loss(quantized_latent, latent.detach()) - commitment_loss = F.mse_loss(quantized_latent.detach(), latent) - return embedding_loss + self.beta * commitment_loss - - def forward(self, latent: Tensor) -> Tensor: - """Forward pass that returns the quantized vector and the vq loss.""" - # Rearrange latent representation s.t. the hidden dim is at the end. - latent = rearrange(latent, "b d h w -> b h w d") - - # Maps latent to the nearest code in the codebook. - quantized_latent = self.discretization_bottleneck(latent) - - loss = self.vq_loss(latent, quantized_latent) - - # Add residue to the quantized latent. - quantized_latent = latent + (quantized_latent - latent).detach() - - # Rearrange the quantized shape back to the original shape. - quantized_latent = rearrange(quantized_latent, "b h w d -> b d h w") - - return quantized_latent, loss diff --git a/src/text_recognizer/networks/vqvae/vqvae.py b/src/text_recognizer/networks/vqvae/vqvae.py deleted file mode 100644 index 50448b4..0000000 --- a/src/text_recognizer/networks/vqvae/vqvae.py +++ /dev/null @@ -1,74 +0,0 @@ -"""The VQ-VAE.""" - -from typing import List, Optional, Tuple, Type - -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.vqvae import Decoder, Encoder - - -class VQVAE(nn.Module): - """Vector Quantized Variational AutoEncoder.""" - - def __init__( - self, - in_channels: int, - channels: List[int], - kernel_sizes: List[int], - strides: List[int], - num_residual_layers: int, - embedding_dim: int, - num_embeddings: int, - upsampling: Optional[List[List[int]]] = None, - beta: float = 0.25, - activation: str = "leaky_relu", - dropout_rate: float = 0.0, - ) -> None: - super().__init__() - - # configure encoder. - self.encoder = Encoder( - in_channels, - channels, - kernel_sizes, - strides, - num_residual_layers, - embedding_dim, - num_embeddings, - beta, - activation, - dropout_rate, - ) - - # Configure decoder. - channels.reverse() - kernel_sizes.reverse() - strides.reverse() - self.decoder = Decoder( - channels, - kernel_sizes, - strides, - num_residual_layers, - embedding_dim, - upsampling, - activation, - dropout_rate, - ) - - def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]: - """Encodes input to a latent code.""" - return self.encoder(x) - - def decode(self, z_q: Tensor) -> Tensor: - """Reconstructs input from latent codes.""" - return self.decoder(z_q) - - def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: - """Compresses and decompresses input.""" - if len(x.shape) < 4: - x = x[(None,) * (4 - len(x.shape))] - z_q, vq_loss = self.encode(x) - x_reconstruction = self.decode(z_q) - return x_reconstruction, vq_loss diff --git a/src/text_recognizer/networks/wide_resnet.py b/src/text_recognizer/networks/wide_resnet.py deleted file mode 100644 index b767778..0000000 --- a/src/text_recognizer/networks/wide_resnet.py +++ /dev/null @@ -1,221 +0,0 @@ -"""Wide Residual CNN.""" -from functools import partial -from typing import Callable, Dict, List, Optional, Type, Union - -from einops.layers.torch import Reduce -import numpy as np -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.util import activation_function - - -def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: - """Helper function for a 3x3 2d convolution.""" - return nn.Conv2d( - in_channels=in_planes, - out_channels=out_planes, - kernel_size=3, - stride=stride, - padding=1, - bias=False, - ) - - -def conv_init(module: Type[nn.Module]) -> None: - """Initializes the weights for convolution and batchnorms.""" - classname = module.__class__.__name__ - if classname.find("Conv") != -1: - nn.init.xavier_uniform_(module.weight, gain=np.sqrt(2)) - nn.init.constant_(module.bias, 0) - elif classname.find("BatchNorm") != -1: - nn.init.constant_(module.weight, 1) - nn.init.constant_(module.bias, 0) - - -class WideBlock(nn.Module): - """Block used in WideResNet.""" - - def __init__( - self, - in_planes: int, - out_planes: int, - dropout_rate: float, - stride: int = 1, - activation: str = "relu", - ) -> None: - super().__init__() - self.in_planes = in_planes - self.out_planes = out_planes - self.dropout_rate = dropout_rate - self.stride = stride - self.activation = activation_function(activation) - - # Build blocks. - self.blocks = nn.Sequential( - nn.BatchNorm2d(self.in_planes), - self.activation, - conv3x3(in_planes=self.in_planes, out_planes=self.out_planes), - nn.Dropout(p=self.dropout_rate), - nn.BatchNorm2d(self.out_planes), - self.activation, - conv3x3( - in_planes=self.out_planes, - out_planes=self.out_planes, - stride=self.stride, - ), - ) - - self.shortcut = ( - nn.Sequential( - nn.Conv2d( - in_channels=self.in_planes, - out_channels=self.out_planes, - kernel_size=1, - stride=self.stride, - bias=False, - ), - ) - if self._apply_shortcut - else None - ) - - @property - def _apply_shortcut(self) -> bool: - """If shortcut should be applied or not.""" - return self.stride != 1 or self.in_planes != self.out_planes - - def forward(self, x: Tensor) -> Tensor: - """Forward pass.""" - residual = x - if self._apply_shortcut: - residual = self.shortcut(x) - x = self.blocks(x) - x += residual - return x - - -class WideResidualNetwork(nn.Module): - """WideResNet for character predictions. - - Can be used for classification or encoding of images to a latent vector. - - """ - - def __init__( - self, - in_channels: int = 1, - in_planes: int = 16, - num_classes: int = 80, - depth: int = 16, - width_factor: int = 10, - dropout_rate: float = 0.0, - num_layers: int = 3, - block: Type[nn.Module] = WideBlock, - num_stages: Optional[List[int]] = None, - activation: str = "relu", - use_decoder: bool = True, - ) -> None: - """The initialization of the WideResNet. - - Args: - in_channels (int): Number of input channels. Defaults to 1. - in_planes (int): Number of channels to use in the first output kernel. Defaults to 16. - num_classes (int): Number of classes. Defaults to 80. - depth (int): Set the number of blocks to use. Defaults to 16. - width_factor (int): Factor for scaling the number of channels in the network. Defaults to 10. - dropout_rate (float): The dropout rate. Defaults to 0.0. - num_layers (int): Number of layers of blocks. Defaults to 3. - block (Type[nn.Module]): The default block is WideBlock. Defaults to WideBlock. - num_stages (List[int]): If given, will use these channel values. Defaults to None. - activation (str): Name of the activation to use. Defaults to "relu". - use_decoder (bool): If True, the network output character predictions, if False, the network outputs a - latent vector. Defaults to True. - - Raises: - RuntimeError: If the depth is not of the size `6n+4`. - - """ - - super().__init__() - if (depth - 4) % 6 != 0: - raise RuntimeError("Wide-resnet depth should be 6n+4") - self.in_channels = in_channels - self.in_planes = in_planes - self.num_classes = num_classes - self.num_blocks = (depth - 4) // 6 - self.width_factor = width_factor - self.num_layers = num_layers - self.block = block - self.dropout_rate = dropout_rate - self.activation = activation_function(activation) - - if num_stages is None: - self.num_stages = [self.in_planes] + [ - self.in_planes * 2 ** n * self.width_factor - for n in range(self.num_layers) - ] - else: - self.num_stages = [self.in_planes] + num_stages - - self.num_stages = list(zip(self.num_stages, self.num_stages[1:])) - self.strides = [1] + [2] * (self.num_layers - 1) - - self.encoder = nn.Sequential( - conv3x3(in_planes=self.in_channels, out_planes=self.in_planes), - *[ - self._configure_wide_layer( - in_planes=in_planes, - out_planes=out_planes, - stride=stride, - activation=activation, - ) - for (in_planes, out_planes), stride in zip( - self.num_stages, self.strides - ) - ], - ) - - self.decoder = ( - nn.Sequential( - nn.BatchNorm2d(self.num_stages[-1][-1], momentum=0.8), - self.activation, - Reduce("b c h w -> b c", "mean"), - nn.Linear( - in_features=self.num_stages[-1][-1], out_features=self.num_classes - ), - ) - if use_decoder - else None - ) - - # self.apply(conv_init) - - def _configure_wide_layer( - self, in_planes: int, out_planes: int, stride: int, activation: str - ) -> List: - strides = [stride] + [1] * (self.num_blocks - 1) - planes = [out_planes] * len(strides) - planes = [(in_planes, out_planes)] + list(zip(planes, planes[1:])) - return nn.Sequential( - *[ - self.block( - in_planes=in_planes, - out_planes=out_planes, - dropout_rate=self.dropout_rate, - stride=stride, - activation=activation, - ) - for (in_planes, out_planes), stride in zip(planes, strides) - ] - ) - - def forward(self, x: Tensor) -> Tensor: - """Feedforward pass.""" - if len(x.shape) < 4: - x = x[(None,) * int(4 - len(x.shape))] - x = self.encoder(x) - if self.decoder is not None: - x = self.decoder(x) - return x -- cgit v1.2.3-70-g09d2