Diffstat (limited to 'text_recognizer/networks')
 -rw-r--r--  text_recognizer/networks/__init__.py |   43 -
 -rw-r--r--  text_recognizer/networks/beam.py     |   83 -
 -rw-r--r--  text_recognizer/networks/cnn.py      |  101 -
 -rw-r--r--  text_recognizer/networks/crnn.py     |  110 -
 -rw-r--r--  text_recognizer/networks/ctc.py      |   58 -
 -rw-r--r--  text_recognizer/networks/densenet.py |  225 -
 -rw-r--r--  text_recognizer/networks/lenet.py    |   68 -
 -rw-r--r--  text_recognizer/networks/metrics.py  |  123 -
 -rw-r--r--  text_recognizer/networks/mlp.py      |   73 -
 -rw-r--r--  text_recognizer/networks/stn.py      |   44 -
 -rw-r--r--  text_recognizer/networks/unet.py     |  255 -
 -rw-r--r--  text_recognizer/networks/vit.py      |  150 -
12 files changed, 0 insertions(+), 1333 deletions(-)
diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py
index 1521355..e69de29 100644
--- a/text_recognizer/networks/__init__.py
+++ b/text_recognizer/networks/__init__.py
@@ -1,43 +0,0 @@
-"""Network modules."""
-from .cnn import CNN
-from .cnn_transformer import CNNTransformer
-from .crnn import ConvolutionalRecurrentNetwork
-from .ctc import greedy_decoder
-from .densenet import DenseNet
-from .lenet import LeNet
-from .metrics import accuracy, cer, wer
-from .mlp import MLP
-from .residual_network import ResidualNetwork, ResidualNetworkEncoder
-from .transducer import load_transducer_loss, TDS2d
-from .transformer import Transformer
-from .unet import UNet
-from .util import sliding_window
-from .vit import ViT
-from .vq_transformer import VQTransformer
-from .vqvae import VQVAE
-from .wide_resnet import WideResidualNetwork
-
-__all__ = [
-    "accuracy",
-    "cer",
-    "CNN",
-    "CNNTransformer",
-    "ConvolutionalRecurrentNetwork",
-    "DenseNet",
-    "FCN",
-    "greedy_decoder",
-    "MLP",
-    "LeNet",
-    "load_transducer_loss",
-    "ResidualNetwork",
-    "ResidualNetworkEncoder",
-    "sliding_window",
-    "UNet",
-    "TDS2d",
-    "Transformer",
-    "ViT",
-    "VQTransformer",
-    "VQVAE",
-    "wer",
-    "WideResidualNetwork",
-]
diff --git a/text_recognizer/networks/beam.py b/text_recognizer/networks/beam.py
deleted file mode 100644
index dccccdb..0000000
--- a/text_recognizer/networks/beam.py
+++ /dev/null
@@ -1,83 +0,0 @@
-"""Implementation of beam search decoder for a sequence to sequence network.
-
-Stolen from: https://github.com/budzianowski/PyTorch-Beam-Search-Decoding/blob/master/decode_beam.py
-
-"""
-# from typing import List
-# from Queue import PriorityQueue
-
-# from loguru import logger
-# import torch
-# from torch import nn
-# from torch import Tensor
-# import torch.nn.functional as F
-
-
-# class Node:
-#     def __init__(
-#         self, parent: Node, target_index: int, log_prob: Tensor, length: int
-#     ) -> None:
-#         self.parent = parent
-#         self.target_index = target_index
-#         self.log_prob = log_prob
-#         self.length = length
-#         self.reward = 0.0
-
-#     def eval(self, alpha: float = 1.0) -> Tensor:
-#         return self.log_prob / (self.length - 1 + 1e-6) + alpha * self.reward
-
-
-# @torch.no_grad()
-# def beam_decoder(
-#     network, mapper, device, memory: Tensor = None, max_len: int = 97,
-# ) -> Tensor:
-#     beam_width = 10
-#     topk = 1  # How many sentences to generate.
-
-#     trg_indices = [mapper(mapper.init_token)]
-
-#     end_nodes = []
-
-#     node = Node(None, trg_indices, 0, 1)
-#     nodes = PriorityQueue()
-
-#     nodes.put((node.eval(), node))
-#     q_size = 1
-
-#     # Beam search
-#     for _ in range(max_len):
-#         if q_size > 2000:
-#             logger.warning("Could not decode input")
-#             break
-
-#         # Fetch the best node.
-#         score, n = nodes.get()
-#         decoder_input = n.target_index
-
-#         if n.target_index == mapper(mapper.eos_token) and n.parent is not None:
-#             end_nodes.append((score, n))
-
-#             # If we reached the maximum number of sentences required.
-#             if len(end_nodes) >= 1:
-#                 break
-#             else:
-#                 continue
-
-#         # Forward pass with transformer.
-#         trg = torch.tensor(trg_indices, device=device)[None, :].long()
-#         trg = network.target_embedding(trg)
-#         logits = network.decoder(trg=trg, memory=memory, trg_mask=None)
-#         log_prob = F.log_softmax(logits, dim=2)
-
-#         log_prob, indices = torch.topk(log_prob, beam_width)
-
-#         for new_k in range(beam_width):
-#             # TODO: continue from here
-#             token_index = indices[0][new_k].view(1, -1)
-#             log_p = log_prob[0][new_k].item()
-
-#             node = Node()
-
-#             pass
-
-#         pass
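
For reference, a minimal, runnable sketch of the priority-queue beam search the commented-out stub above was building towards. The toy `step` successor function and the token ids (`1` as start, `0` as end-of-sequence) are hypothetical stand-ins for the network's decoder:

    import heapq
    from typing import List, Tuple

    def step(prefix: List[int]) -> List[Tuple[int, float]]:
        # Toy successor function: returns (token, log_prob) candidates.
        return [(1, -0.1), (2, -1.5)] if len(prefix) < 4 else [(0, -0.01)]

    def beam_search(beam_width: int = 3, max_len: int = 8) -> List[int]:
        beams = [(0.0, [1])]  # (cumulative log prob, prefix starting with <sos>)
        for _ in range(max_len):
            candidates = []
            for log_p, prefix in beams:
                if prefix[-1] == 0:  # <eos> reached, keep the beam as-is.
                    candidates.append((log_p, prefix))
                    continue
                for token, token_log_p in step(prefix):
                    candidates.append((log_p + token_log_p, prefix + [token]))
            # Keep only the beam_width highest scoring prefixes.
            beams = heapq.nlargest(beam_width, candidates, key=lambda c: c[0])
        # Return the best prefix, length-normalised as in Node.eval above.
        return max(beams, key=lambda c: c[0] / len(c[1]))[1]

    print(beam_search())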
diff --git a/text_recognizer/networks/cnn.py b/text_recognizer/networks/cnn.py
deleted file mode 100644
index 1807bb9..0000000
--- a/text_recognizer/networks/cnn.py
+++ /dev/null
@@ -1,101 +0,0 @@
-"""Implementation of a simple backbone cnn network."""
-from typing import Callable, Dict, Optional, Tuple
-
-from einops.layers.torch import Rearrange
-import torch
-from torch import nn
-
-from text_recognizer.networks.util import activation_function
-
-
-class CNN(nn.Module):
-    """A simple CNN backbone for character prediction."""
-
-    def __init__(
-        self,
-        channels: Tuple[int, ...] = (1, 32, 64, 128),
-        kernel_sizes: Tuple[int, ...] = (4, 4, 4),
-        strides: Tuple[int, ...] = (2, 2, 2),
-        max_pool_kernel: int = 2,
-        dropout_rate: float = 0.2,
-        activation: Optional[str] = "relu",
-    ) -> None:
-        """Initialization of the CNN network.
-
-        Args:
-            channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64, 128).
-            kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (4, 4, 4).
-            strides (Tuple[int, ...]): Stride length of the convolutional filter. Defaults to (2, 2, 2).
-            max_pool_kernel (int): 2D max pooling kernel. Defaults to 2.
-            dropout_rate (float): The dropout rate. Defaults to 0.2.
-            activation (Optional[str]): The name of non-linear activation function. Defaults to relu.
-
-        Raises:
-            RuntimeError: if the number of hyperparameters does not match in length.
-
-        """
-        super().__init__()
-
-        if len(channels) - 1 != len(kernel_sizes) or len(kernel_sizes) != len(strides):
-            raise RuntimeError("The number of the hyperparameters does not match.")
-
-        self.cnn = self._build_network(
-            channels, kernel_sizes, strides, max_pool_kernel, dropout_rate, activation,
-        )
-
-    def _build_network(
-        self,
-        channels: Tuple[int, ...],
-        kernel_sizes: Tuple[int, ...],
-        strides: Tuple[int, ...],
-        max_pool_kernel: int,
-        dropout_rate: float,
-        activation: str,
-    ) -> nn.Sequential:
-        # Load activation function.
-        activation_fn = activation_function(activation)
-
-        channels = list(channels)
-        in_channels = channels.pop(0)
-        configuration = zip(channels, kernel_sizes, strides)
-
-        modules = nn.ModuleList([])
-
-        for i, (out_channels, kernel_size, stride) in enumerate(configuration):
-            # Add max pool to reduce output size.
-            if i == len(channels) // 2:
-                modules.append(nn.MaxPool2d(max_pool_kernel))
-            if i == 0:
-                modules.append(
-                    nn.Conv2d(
-                        in_channels, out_channels, kernel_size, stride=stride, padding=1
-                    )
-                )
-            else:
-                modules.append(
-                    nn.Sequential(
-                        activation_fn,
-                        nn.BatchNorm2d(in_channels),
-                        nn.Conv2d(
-                            in_channels,
-                            out_channels,
-                            kernel_size,
-                            stride=stride,
-                            padding=1,
-                        ),
-                    )
-                )
-
-            if dropout_rate:
-                modules.append(nn.Dropout2d(p=dropout_rate))
-
-            in_channels = out_channels
-
-        return nn.Sequential(*modules)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """The feedforward pass."""
-        # If the batch dimension is missing, it needs to be added.
-        if len(x.shape) < 4:
-            x = x[(None,) * (4 - len(x.shape))]
-        return self.cnn(x)
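
A rough usage sketch for the deleted backbone above, assuming the module were still importable; the (28, 952) line-image size is illustrative:

    import torch
    from text_recognizer.networks.cnn import CNN

    net = CNN(channels=(1, 32, 64, 128), kernel_sizes=(4, 4, 4), strides=(2, 2, 2))
    x = torch.randn(8, 1, 28, 952)  # batch of line images
    features = net(x)               # three stride-2 convs plus one 2x2 max pool
    print(features.shape)           # torch.Size([8, 128, 1, 59]) for this input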
diff --git a/text_recognizer/networks/crnn.py b/text_recognizer/networks/crnn.py
deleted file mode 100644
index 778e232..0000000
--- a/text_recognizer/networks/crnn.py
+++ /dev/null
@@ -1,110 +0,0 @@
-"""CRNN for handwritten text recognition."""
-from typing import Dict, Tuple
-
-from einops import rearrange, reduce
-from einops.layers.torch import Rearrange
-from loguru import logger
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import configure_backbone
-
-
-class ConvolutionalRecurrentNetwork(nn.Module):
-    """Network that takes an image of a text line and predicts the tokens in the image."""
-
-    def __init__(
-        self,
-        backbone: str,
-        backbone_args: Dict = None,
-        input_size: int = 128,
-        hidden_size: int = 128,
-        bidirectional: bool = False,
-        num_layers: int = 1,
-        num_classes: int = 80,
-        patch_size: Tuple[int, int] = (28, 28),
-        stride: Tuple[int, int] = (1, 14),
-        recurrent_cell: str = "lstm",
-        avg_pool: bool = False,
-        use_sliding_window: bool = True,
-    ) -> None:
-        super().__init__()
-        self.backbone_args = backbone_args or {}
-        self.patch_size = patch_size
-        self.stride = stride
-        self.sliding_window = (
-            self._configure_sliding_window() if use_sliding_window else None
-        )
-        self.input_size = input_size
-        self.hidden_size = hidden_size
-        self.backbone = configure_backbone(backbone, backbone_args)
-        self.bidirectional = bidirectional
-        self.avg_pool = avg_pool
-
-        if recurrent_cell.upper() in ["LSTM", "GRU"]:
-            recurrent_cell = getattr(nn, recurrent_cell)
-        else:
-            logger.warning(
-                f"Option {recurrent_cell} not valid, defaulting to LSTM cell."
-            )
-            recurrent_cell = nn.LSTM
-
-        self.rnn = recurrent_cell(
-            input_size=self.input_size,
-            hidden_size=self.hidden_size,
-            bidirectional=bidirectional,
-            num_layers=num_layers,
-        )
-
-        decoder_size = self.hidden_size * 2 if self.bidirectional else self.hidden_size
-
-        self.decoder = nn.Sequential(
-            nn.Linear(in_features=decoder_size, out_features=num_classes),
-            nn.LogSoftmax(dim=2),
-        )
-
-    def _configure_sliding_window(self) -> nn.Sequential:
-        return nn.Sequential(
-            nn.Unfold(kernel_size=self.patch_size, stride=self.stride),
-            Rearrange(
-                "b (c h w) t -> b t c h w",
-                h=self.patch_size[0],
-                w=self.patch_size[1],
-                c=1,
-            ),
-        )
-
-    def forward(self, x: Tensor) -> Tensor:
-        """Converts images to sequences of patches, feeds them to a CNN, then makes predictions with an LSTM."""
-        if len(x.shape) < 4:
-            x = x[(None,) * (4 - len(x.shape))]
-
-        if self.sliding_window is not None:
-            # Create image patches with a sliding window kernel.
-            x = self.sliding_window(x)
-
-            # Rearrange from a sequence of patches for feedforward network.
-            b, t = x.shape[:2]
-            x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
-
-            x = self.backbone(x)
-
-            # Average pooling.
-            if self.avg_pool:
-                x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)
-            else:
-                x = rearrange(x, "(b t) h -> t b h", b=b, t=t)
-        else:
-            # Encode the entire image with a CNN, and use the channels as the temporal dimension.
-            x = self.backbone(x)
-            x = rearrange(x, "b c h w -> b w c h")
-            # Note: adaptive_pool is never set in __init__, so guard the lookup.
-            if getattr(self, "adaptive_pool", None) is not None:
-                x = self.adaptive_pool(x)
-            x = x.squeeze(3)
-
-        # Sequence predictions.
-        x, _ = self.rnn(x)
-
-        # Sequence to classification layer.
-        x = self.decoder(x)
-        return x
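
The sliding-window patching above can be reproduced standalone: nn.Unfold cuts (28, 28) patches every 14 pixels of width, and Rearrange restores them as image patches. The (28, 952) input size is illustrative:

    import torch
    from torch import nn
    from einops.layers.torch import Rearrange

    window = nn.Sequential(
        nn.Unfold(kernel_size=(28, 28), stride=(1, 14)),
        Rearrange("b (c h w) t -> b t c h w", c=1, h=28, w=28),
    )
    x = torch.randn(2, 1, 28, 952)  # (batch, channel, height, width) line image
    patches = window(x)
    print(patches.shape)  # torch.Size([2, 67, 1, 28, 28]); 67 = (952 - 28) // 14 + 1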
diff --git a/text_recognizer/networks/ctc.py b/text_recognizer/networks/ctc.py
deleted file mode 100644
index af9b700..0000000
--- a/text_recognizer/networks/ctc.py
+++ /dev/null
@@ -1,58 +0,0 @@
-"""Decodes the CTC output."""
-from typing import Callable, List, Optional, Tuple
-
-from einops import rearrange
-import torch
-from torch import Tensor
-
-from text_recognizer.datasets.util import EmnistMapper
-
-
-def greedy_decoder(
-    predictions: Tensor,
-    targets: Optional[Tensor] = None,
-    target_lengths: Optional[Tensor] = None,
-    character_mapper: Optional[Callable] = None,
-    blank_label: int = 79,
-    collapse_repeated: bool = True,
-) -> Tuple[List[str], List[str]]:
-    """Greedy CTC decoder.
-
-    Args:
-        predictions (Tensor): Tensor of network predictions, shape [time, batch, classes].
-        targets (Optional[Tensor]): Target tensor, shape is [batch, targets]. Defaults to None.
-        target_lengths (Optional[Tensor]): Length of each target tensor. Defaults to None.
-        character_mapper (Optional[Callable]): An EMNIST/character mapper for mapping integers to characters. Defaults
-            to None.
-        blank_label (int): The blank character to be ignored. Defaults to 79.
-        collapse_repeated (bool): Collapse consecutive predictions of the same character. Defaults to True.
-
-    Returns:
-        Tuple[List[str], List[str]]: Tuple of decoded predictions and decoded targets.
-
-    """
-
-    if character_mapper is None:
-        character_mapper = EmnistMapper(pad_token="_")  # noqa: S106
-
-    predictions = rearrange(torch.argmax(predictions, dim=2), "t b -> b t")
-    decoded_predictions = []
-    decoded_targets = []
-    for i, prediction in enumerate(predictions):
-        decoded_prediction = []
-        decoded_target = []
-        if targets is not None and target_lengths is not None:
-            for target_index in targets[i][: target_lengths[i]]:
-                if target_index == blank_label:
-                    continue
-                decoded_target.append(character_mapper(int(target_index)))
-            decoded_targets.append(decoded_target)
-        for j, index in enumerate(prediction):
-            if index != blank_label:
-                if collapse_repeated and j != 0 and index == prediction[j - 1]:
-                    continue
-                decoded_prediction.append(index.item())
-        decoded_predictions.append(
-            [character_mapper(int(pred_index)) for pred_index in decoded_prediction]
-        )
    return decoded_predictions, decoded_targets
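
A minimal illustration of the greedy decoding rule implemented above: argmax per frame, drop blanks, collapse repeats (blank id 79 as in the signature; the frame ids are made up):

    import torch

    blank = 79
    frames = torch.tensor([5, 5, 79, 5, 8, 8, 79, 79, 3])  # per-frame argmax ids
    decoded, prev = [], None
    for idx in frames.tolist():
        if idx != blank and idx != prev:
            decoded.append(idx)
        prev = idx
    print(decoded)  # [5, 5, 8, 3] -- the blank between the 5s keeps them distinct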
diff --git a/text_recognizer/networks/densenet.py b/text_recognizer/networks/densenet.py
deleted file mode 100644
index 7dc58d9..0000000
--- a/text_recognizer/networks/densenet.py
+++ /dev/null
@@ -1,225 +0,0 @@
-"""Defines a Densely Connected Convolutional Network (DenseNet) in PyTorch.
-
-Sources:
-https://arxiv.org/abs/1608.06993
-https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py
-
-"""
-from typing import List, Optional, Union
-
-from einops.layers.torch import Rearrange
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import activation_function
-
-
-class _DenseLayer(nn.Module):
-    """A dense layer with pre-batch norm -> activation function -> Conv-layer x 2."""
-
-    def __init__(
-        self,
-        in_channels: int,
-        growth_rate: int,
-        bn_size: int,
-        dropout_rate: float,
-        activation: str = "relu",
-    ) -> None:
-        super().__init__()
-        activation_fn = activation_function(activation)
-        self.dense_layer = [
-            nn.BatchNorm2d(in_channels),
-            activation_fn,
-            nn.Conv2d(
-                in_channels=in_channels,
-                out_channels=bn_size * growth_rate,
-                kernel_size=1,
-                stride=1,
-                bias=False,
-            ),
-            nn.BatchNorm2d(bn_size * growth_rate),
-            activation_fn,
-            nn.Conv2d(
-                in_channels=bn_size * growth_rate,
-                out_channels=growth_rate,
-                kernel_size=3,
-                stride=1,
-                padding=1,
-                bias=False,
-            ),
-        ]
-        if dropout_rate:
-            self.dense_layer.append(nn.Dropout(p=dropout_rate))
-
-        self.dense_layer = nn.Sequential(*self.dense_layer)
-
-    def forward(self, x: Union[Tensor, List[Tensor]]) -> Tensor:
-        if isinstance(x, list):
-            x = torch.cat(x, 1)
-        return self.dense_layer(x)
-
-
-class _DenseBlock(nn.Module):
-    def __init__(
-        self,
-        num_layers: int,
-        in_channels: int,
-        bn_size: int,
-        growth_rate: int,
-        dropout_rate: float,
-        activation: str = "relu",
-    ) -> None:
-        super().__init__()
-        self.dense_block = self._build_dense_blocks(
-            num_layers, in_channels, bn_size, growth_rate, dropout_rate, activation,
-        )
-
-    def _build_dense_blocks(
-        self,
-        num_layers: int,
-        in_channels: int,
-        bn_size: int,
-        growth_rate: int,
-        dropout_rate: float,
-        activation: str = "relu",
-    ) -> nn.ModuleList:
-        dense_block = []
-        for i in range(num_layers):
-            dense_block.append(
-                _DenseLayer(
-                    in_channels=in_channels + i * growth_rate,
-                    growth_rate=growth_rate,
-                    bn_size=bn_size,
-                    dropout_rate=dropout_rate,
-                    activation=activation,
-                )
-            )
-        return nn.ModuleList(dense_block)
-
-    def forward(self, x: Tensor) -> Tensor:
-        feature_maps = [x]
-        for layer in self.dense_block:
-            x = layer(feature_maps)
-            feature_maps.append(x)
-        return torch.cat(feature_maps, 1)
-
-
-class _Transition(nn.Module):
-    def __init__(
-        self, in_channels: int, out_channels: int, activation: str = "relu",
-    ) -> None:
-        super().__init__()
-        activation_fn = activation_function(activation)
-        self.transition = nn.Sequential(
-            nn.BatchNorm2d(in_channels),
-            activation_fn,
-            nn.Conv2d(
-                in_channels=in_channels,
-                out_channels=out_channels,
-                kernel_size=1,
-                stride=1,
-                bias=False,
-            ),
-            nn.AvgPool2d(kernel_size=2, stride=2),
-        )
-
-    def forward(self, x: Tensor) -> Tensor:
-        return self.transition(x)
-
-
-class DenseNet(nn.Module):
-    """Implementation of DenseNet, a network architecture that concatenates previous layers for maximum information flow."""
-
-    def __init__(
-        self,
-        growth_rate: int = 32,
-        block_config: List[int] = (6, 12, 24, 16),
-        in_channels: int = 1,
-        base_channels: int = 64,
-        num_classes: int = 80,
-        bn_size: int = 4,
-        dropout_rate: float = 0,
-        classifier: bool = True,
-        activation: str = "relu",
-    ) -> None:
-        super().__init__()
-        self.densenet = self._configure_densenet(
-            in_channels,
-            base_channels,
-            num_classes,
-            growth_rate,
-            block_config,
-            bn_size,
-            dropout_rate,
-            classifier,
-            activation,
-        )
-
-    def _configure_densenet(
-        self,
-        in_channels: int,
-        base_channels: int,
-        num_classes: int,
-        growth_rate: int,
-        block_config: List[int],
-        bn_size: int,
-        dropout_rate: float,
-        classifier: bool,
-        activation: str,
-    ) -> nn.Sequential:
-        activation_fn = activation_function(activation)
-        densenet = [
-            nn.Conv2d(
-                in_channels=in_channels,
-                out_channels=base_channels,
-                kernel_size=3,
-                stride=1,
-                padding=1,
-                bias=False,
-            ),
-            nn.BatchNorm2d(base_channels),
-            activation_fn,
-        ]
-
-        num_features = base_channels
-
-        for i, num_layers in enumerate(block_config):
-            densenet.append(
-                _DenseBlock(
-                    num_layers=num_layers,
-                    in_channels=num_features,
-                    bn_size=bn_size,
-                    growth_rate=growth_rate,
-                    dropout_rate=dropout_rate,
-                    activation=activation,
-                )
-            )
-            num_features = num_features + num_layers * growth_rate
-            if i != len(block_config) - 1:
-                densenet.append(
-                    _Transition(
-                        in_channels=num_features,
-                        out_channels=num_features // 2,
-                        activation=activation,
-                    )
-                )
-                num_features = num_features // 2
-
-        densenet.append(activation_fn)
-
-        if classifier:
-            densenet.append(nn.AdaptiveAvgPool2d((1, 1)))
-            densenet.append(Rearrange("b c h w -> b (c h w)"))
-            densenet.append(
-                nn.Linear(in_features=num_features, out_features=num_classes)
-            )
-
-        return nn.Sequential(*densenet)
-
-    def forward(self, x: Tensor) -> Tensor:
-        """Forward pass of DenseNet."""
-        # If the batch dimension is missing, it will be added.
-        if len(x.shape) < 4:
-            x = x[(None,) * (4 - len(x.shape))]
-        return self.densenet(x)
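
Channel bookkeeping implied by the defaults above (growth_rate=32, block_config=(6, 12, 24, 16), base_channels=64): each dense block adds num_layers * growth_rate channels, and each transition halves them:

    num_features = 64
    for i, num_layers in enumerate((6, 12, 24, 16)):
        num_features += num_layers * 32
        print(f"after dense block {i}: {num_features} channels")  # 256, 512, 1024, 1024
        if i != 3:
            num_features //= 2  # transition layer halves the channels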
diff --git a/text_recognizer/networks/lenet.py b/text_recognizer/networks/lenet.py
deleted file mode 100644
index 527e1a0..0000000
--- a/text_recognizer/networks/lenet.py
+++ /dev/null
@@ -1,68 +0,0 @@
-"""Implementation of the LeNet network."""
-from typing import Callable, Dict, Optional, Tuple
-
-from einops.layers.torch import Rearrange
-import torch
-from torch import nn
-
-from text_recognizer.networks.util import activation_function
-
-
-class LeNet(nn.Module):
-    """LeNet network for character prediction."""
-
-    def __init__(
-        self,
-        channels: Tuple[int, ...] = (1, 32, 64),
-        kernel_sizes: Tuple[int, ...] = (3, 3, 2),
-        hidden_size: Tuple[int, ...] = (9216, 128),
-        dropout_rate: float = 0.2,
-        num_classes: int = 10,
-        activation_fn: Optional[str] = "relu",
-    ) -> None:
-        """Initialization of the LeNet network.
-
-        Args:
-            channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64).
-            kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2).
-            hidden_size (Tuple[int, ...]): Size of the flattened output from the convolutional layers.
-                Defaults to (9216, 128).
-            dropout_rate (float): The dropout rate. Defaults to 0.2.
-            num_classes (int): Number of classes. Defaults to 10.
-            activation_fn (Optional[str]): The name of non-linear activation function. Defaults to relu.
-
-        """
-        super().__init__()
-
-        activation_fn = activation_function(activation_fn)
-
-        self.layers = [
-            nn.Conv2d(
-                in_channels=channels[0],
-                out_channels=channels[1],
-                kernel_size=kernel_sizes[0],
-            ),
-            activation_fn,
-            nn.Conv2d(
-                in_channels=channels[1],
-                out_channels=channels[2],
-                kernel_size=kernel_sizes[1],
-            ),
-            activation_fn,
-            nn.MaxPool2d(kernel_sizes[2]),
-            nn.Dropout(p=dropout_rate),
-            Rearrange("b c h w -> b (c h w)"),
-            nn.Linear(in_features=hidden_size[0], out_features=hidden_size[1]),
-            activation_fn,
-            nn.Dropout(p=dropout_rate),
-            nn.Linear(in_features=hidden_size[1], out_features=num_classes),
-        ]
-
-        self.layers = nn.Sequential(*self.layers)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """The feedforward pass."""
-        # If the batch dimension is missing, it needs to be added.
-        if len(x.shape) < 4:
-            x = x[(None,) * (4 - len(x.shape))]
-        return self.layers(x)
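
Where the 9216 default for hidden_size comes from, assuming the canonical 28x28 input: two valid 3x3 convolutions give 24x24 feature maps, and a 2x2 max pool leaves 64 * 12 * 12 = 9216 features:

    h = w = 28
    h, w = h - 3 + 1, w - 3 + 1  # conv 3x3, no padding -> 26
    h, w = h - 3 + 1, w - 3 + 1  # conv 3x3, no padding -> 24
    h, w = h // 2, w // 2        # max pool 2x2 -> 12
    print(64 * h * w)            # 9216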
diff --git a/text_recognizer/networks/metrics.py b/text_recognizer/networks/metrics.py
deleted file mode 100644
index 2605731..0000000
--- a/text_recognizer/networks/metrics.py
+++ /dev/null
@@ -1,123 +0,0 @@
-"""Utility functions for models."""
-from typing import Optional
-
-from einops import rearrange
-import Levenshtein as Lev
-import torch
-from torch import Tensor
-
-from text_recognizer.networks import greedy_decoder
-
-
-def accuracy(outputs: Tensor, labels: Tensor, pad_index: int = 53) -> float:
-    """Computes the accuracy.
-
-    Args:
-        outputs (Tensor): The output from the network.
-        labels (Tensor): Ground truth labels.
-        pad_index (int): Padding index.
-
-    Returns:
-        float: The accuracy for the batch.
-
-    """
-
-    _, predicted = torch.max(outputs, dim=-1)
-
-    # Mask out the pad tokens
-    mask = labels != pad_index
-
-    predicted *= mask
-    labels *= mask
-
-    acc = ((predicted == labels) & mask).sum().float() / mask.sum().float()
-    acc = acc.item()
-    return acc
-
-
-def cer(
-    outputs: Tensor,
-    targets: Tensor,
-    batch_size: Optional[int] = None,
-    blank_label: Optional[int] = 79,
-) -> float:
-    """Computes the character error rate.
-
-    Args:
-        outputs (Tensor): The output from the network.
-        targets (Tensor): Ground truth labels.
-        batch_size (Optional[int]): Batch size if target and output have been flattened.
-        blank_label (Optional[int]): The blank character to be ignored. Defaults to 79.
-
-    Returns:
-        float: The cer for the batch.
-
-    """
-    if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None:
-        targets = rearrange(targets, "(b t) -> b t", b=batch_size)
-        outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size)
-
-    target_lengths = torch.full(
-        size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
-    )
-    decoded_predictions, decoded_targets = greedy_decoder(
-        outputs, targets, target_lengths, blank_label=blank_label,
-    )
-
-    lev_dist = 0
-
-    for prediction, target in zip(decoded_predictions, decoded_targets):
-        prediction = "".join(prediction)
-        target = "".join(target)
-        prediction, target = (
-            prediction.replace(" ", ""),
-            target.replace(" ", ""),
-        )
-        lev_dist += Lev.distance(prediction, target)
-    return lev_dist / len(decoded_predictions)
-
-
-def wer(
-    outputs: Tensor,
-    targets: Tensor,
-    batch_size: Optional[int] = None,
-    blank_label: Optional[int] = 79,
-) -> float:
-    """Computes the word error rate.
-
-    Args:
-        outputs (Tensor): The output from the network.
-        targets (Tensor): Ground truth labels.
-        batch_size (Optional[int]): Batch size if target and output have been flattened.
-        blank_label (Optional[int]): The blank character to be ignored. Defaults to 79.
-
-    Returns:
-        float: The wer for the batch.
-
-    """
-    if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None:
-        targets = rearrange(targets, "(b t) -> b t", b=batch_size)
-        outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size)
-
-    target_lengths = torch.full(
-        size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
-    )
-    decoded_predictions, decoded_targets = greedy_decoder(
-        outputs, targets, target_lengths, blank_label=blank_label,
-    )
-
-    lev_dist = 0
-
-    for prediction, target in zip(decoded_predictions, decoded_targets):
-        prediction = "".join(prediction)
-        target = "".join(target)
-
-        b = set(prediction.split() + target.split())
-        word2char = dict(zip(b, range(len(b))))
-
-        w1 = [chr(word2char[w]) for w in prediction.split()]
-        w2 = [chr(word2char[w]) for w in target.split()]
-
-        lev_dist += Lev.distance("".join(w1), "".join(w2))
-
-    return lev_dist / len(decoded_predictions)
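
The word-to-character trick used in `wer` above, in isolation: map each distinct word to a single character so Levenshtein distance operates on words rather than characters (the example strings are made up):

    import Levenshtein as Lev

    prediction, target = "the cat sat", "the cat sat down"
    vocab = set(prediction.split() + target.split())
    word2char = {w: chr(i) for i, w in enumerate(vocab)}
    w1 = "".join(word2char[w] for w in prediction.split())
    w2 = "".join(word2char[w] for w in target.split())
    print(Lev.distance(w1, w2))  # 1 -- a single word insertion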
diff --git a/text_recognizer/networks/mlp.py b/text_recognizer/networks/mlp.py
deleted file mode 100644
index 1101912..0000000
--- a/text_recognizer/networks/mlp.py
+++ /dev/null
@@ -1,73 +0,0 @@
-"""Defines the MLP network."""
-from typing import Callable, Dict, List, Optional, Union
-
-from einops.layers.torch import Rearrange
-import torch
-from torch import nn
-
-from text_recognizer.networks.util import activation_function
-
-
-class MLP(nn.Module):
-    """Multilayer perceptron network."""
-
-    def __init__(
-        self,
-        input_size: int = 784,
-        num_classes: int = 10,
-        hidden_size: Union[int, List] = 128,
-        num_layers: int = 3,
-        dropout_rate: float = 0.2,
-        activation_fn: str = "relu",
-    ) -> None:
-        """Initialization of the MLP network.
-
-        Args:
-            input_size (int): The input shape of the network. Defaults to 784.
-            num_classes (int): Number of classes in the dataset. Defaults to 10.
-            hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128.
-            num_layers (int): The number of hidden layers. Defaults to 3.
-            dropout_rate (float): The dropout rate at each layer. Defaults to 0.2.
-            activation_fn (str): Name of the activation function in the hidden layers. Defaults to
-                relu.
-
-        """
-        super().__init__()
-
-        activation_fn = activation_function(activation_fn)
-
-        if isinstance(hidden_size, int):
-            hidden_size = [hidden_size] * num_layers
-
-        self.layers = [
-            Rearrange("b c h w -> b (c h w)"),
-            nn.Linear(in_features=input_size, out_features=hidden_size[0]),
-            activation_fn,
-        ]
-
-        for i in range(num_layers - 1):
-            self.layers += [
-                nn.Linear(in_features=hidden_size[i], out_features=hidden_size[i + 1]),
-                activation_fn,
-            ]
-
-            if dropout_rate:
-                self.layers.append(nn.Dropout(p=dropout_rate))
-
-        self.layers.append(
-            nn.Linear(in_features=hidden_size[-1], out_features=num_classes)
-        )
-
-        self.layers = nn.Sequential(*self.layers)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """The feedforward pass."""
-        # If the batch dimension is missing, it needs to be added.
-        if len(x.shape) < 4:
-            x = x[(None,) * (4 - len(x.shape))]
-        return self.layers(x)
-
-    @property
-    def __name__(self) -> str:
-        """Returns the name of the network."""
-        return "mlp"
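
A rough usage sketch for the MLP above on flattened 28x28 input, assuming the module were still importable:

    import torch
    from text_recognizer.networks.mlp import MLP

    net = MLP(input_size=784, num_classes=10, hidden_size=128, num_layers=3)
    logits = net(torch.randn(16, 1, 28, 28))  # Rearrange flattens to (16, 784)
    print(logits.shape)                       # torch.Size([16, 10])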
diff --git a/text_recognizer/networks/stn.py b/text_recognizer/networks/stn.py
deleted file mode 100644
index e9d216f..0000000
--- a/text_recognizer/networks/stn.py
+++ /dev/null
@@ -1,44 +0,0 @@
-"""Spatial Transformer Network."""
-
-from einops.layers.torch import Rearrange
-import torch
-from torch import nn
-from torch import Tensor
-import torch.nn.functional as F
-
-
-class SpatialTransformerNetwork(nn.Module):
-    """A network with differentiable attention.
-
-    Network that learns how to perform spatial transformations on the input image in order to enhance the
-    geometric invariance of the model.
-
-    # TODO: add arguments to make it more general.
-
-    """
-
-    def __init__(self) -> None:
-        super().__init__()
-        # Initialize the identity transformation and its weights and biases.
-        linear = nn.Linear(32, 3 * 2)
-        linear.weight.data.zero_()
-        linear.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
-
-        self.theta = nn.Sequential(
-            nn.Conv2d(in_channels=1, out_channels=8, kernel_size=7),
-            nn.MaxPool2d(kernel_size=2, stride=2),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(in_channels=8, out_channels=10, kernel_size=5),
-            nn.MaxPool2d(kernel_size=2, stride=2),
-            nn.ReLU(inplace=True),
-            Rearrange("b c h w -> b (c h w)", h=3, w=3),
-            nn.Linear(in_features=10 * 3 * 3, out_features=32),
-            nn.ReLU(inplace=True),
-            linear,
-            Rearrange("b (row col) -> b row col", row=2, col=3),
-        )
-
-    def forward(self, x: Tensor) -> Tensor:
-        """The spatial transformation."""
-        grid = F.affine_grid(self.theta(x), x.shape)
-        return F.grid_sample(x, grid, align_corners=False)
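
The core mechanism above in isolation: with the identity theta that the final linear layer is initialized to, affine_grid plus grid_sample return the input essentially unchanged, so the network starts from a no-op warp and only learns deviations from it:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 1, 28, 28)
    theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])  # identity transform
    grid = F.affine_grid(theta, x.shape, align_corners=False)
    y = F.grid_sample(x, grid, align_corners=False)
    print(torch.allclose(x, y, atol=1e-5))  # True, up to interpolation error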
diff --git a/text_recognizer/networks/unet.py b/text_recognizer/networks/unet.py
deleted file mode 100644
index 510910f..0000000
--- a/text_recognizer/networks/unet.py
+++ /dev/null
@@ -1,255 +0,0 @@
-"""UNet for segmentation."""
-from typing import List, Optional, Tuple, Union
-
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import activation_function
-
-
-class _ConvBlock(nn.Module):
-    """Modified UNet convolutional block with dilation."""
-
-    def __init__(
-        self,
-        channels: List[int],
-        activation: str,
-        num_groups: int,
-        dropout_rate: float = 0.1,
-        kernel_size: int = 3,
-        dilation: int = 1,
-        padding: int = 0,
-    ) -> None:
-        super().__init__()
-        self.channels = channels
-        self.dropout_rate = dropout_rate
-        self.kernel_size = kernel_size
-        self.dilation = dilation
-        self.padding = padding
-        self.num_groups = num_groups
-        self.activation = activation_function(activation)
-        self.block = self._configure_block()
-        self.residual_conv = nn.Sequential(
-            nn.Conv2d(
-                self.channels[0], self.channels[-1], kernel_size=3, stride=1, padding=1
-            ),
-            self.activation,
-        )
-
-    def _configure_block(self) -> nn.Sequential:
-        block = []
-        for i in range(len(self.channels) - 1):
-            block += [
-                nn.Dropout(p=self.dropout_rate),
-                nn.GroupNorm(self.num_groups, self.channels[i]),
-                self.activation,
-                nn.Conv2d(
-                    self.channels[i],
-                    self.channels[i + 1],
-                    kernel_size=self.kernel_size,
-                    padding=self.padding,
-                    stride=1,
-                    dilation=self.dilation,
-                ),
-            ]
-
-        return nn.Sequential(*block)
-
-    def forward(self, x: Tensor) -> Tensor:
-        """Apply the convolutional block."""
-        residual = self.residual_conv(x)
-        return self.block(x) + residual
-
-
-class _DownSamplingBlock(nn.Module):
-    """Basic down sampling block."""
-
-    def __init__(
-        self,
-        channels: List[int],
-        activation: str,
-        num_groups: int,
-        pooling_kernel: Union[int, bool] = 2,
-        dropout_rate: float = 0.1,
-        kernel_size: int = 3,
-        dilation: int = 1,
-        padding: int = 0,
-    ) -> None:
-        super().__init__()
-        self.conv_block = _ConvBlock(
-            channels,
-            activation,
-            num_groups,
-            dropout_rate,
-            kernel_size,
-            dilation,
-            padding,
-        )
-        self.down_sampling = nn.MaxPool2d(pooling_kernel) if pooling_kernel else None
-
-    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
-        """Return the convolutional block output and a down sampled tensor."""
-        x = self.conv_block(x)
-        x_down = self.down_sampling(x) if self.down_sampling is not None else x
-
-        return x_down, x
-
-
-class _UpSamplingBlock(nn.Module):
-    """The upsampling block of the UNet."""
-
-    def __init__(
-        self,
-        channels: List[int],
-        activation: str,
-        num_groups: int,
-        scale_factor: int = 2,
-        dropout_rate: float = 0.1,
-        kernel_size: int = 3,
-        dilation: int = 1,
-        padding: int = 0,
-    ) -> None:
-        super().__init__()
-        self.conv_block = _ConvBlock(
-            channels,
-            activation,
-            num_groups,
-            dropout_rate,
-            kernel_size,
-            dilation,
-            padding,
-        )
-        self.up_sampling = nn.Upsample(
-            scale_factor=scale_factor, mode="bilinear", align_corners=True
-        )
-
-    def forward(self, x: Tensor, x_skip: Optional[Tensor] = None) -> Tensor:
-        """Apply the up sampling and convolutional block."""
-        x = self.up_sampling(x)
-        if x_skip is not None:
-            x = torch.cat((x, x_skip), dim=1)
-        return self.conv_block(x)
-
-
-class UNet(nn.Module):
-    """UNet architecture."""
-
-    def __init__(
-        self,
-        in_channels: int = 1,
-        base_channels: int = 64,
-        num_classes: int = 3,
-        depth: int = 4,
-        activation: str = "relu",
-        num_groups: int = 8,
-        dropout_rate: float = 0.1,
-        pooling_kernel: int = 2,
-        scale_factor: int = 2,
-        kernel_size: Optional[List[int]] = None,
-        dilation: Optional[List[int]] = None,
-        padding: Optional[List[int]] = None,
-    ) -> None:
-        super().__init__()
-        self.depth = depth
-        self.num_groups = num_groups
-
-        if kernel_size is not None and dilation is not None and padding is not None:
-            if (
-                len(kernel_size) != depth
-                or len(dilation) != depth
-                or len(padding) != depth
-            ):
-                raise RuntimeError(
-                    "Length of convolutional parameters does not match the depth."
-                )
-            self.kernel_size = kernel_size
-            self.padding = padding
-            self.dilation = dilation
-
-        else:
-            self.kernel_size = [3] * depth
-            self.padding = [1] * depth
-            self.dilation = [1] * depth
-
-        self.dropout_rate = dropout_rate
-        self.conv = nn.Conv2d(
-            in_channels, base_channels, kernel_size=3, stride=1, padding=1
-        )
-
-        channels = [base_channels] + [base_channels * 2 ** i for i in range(depth)]
-        self.encoder_blocks = self._configure_down_sampling_blocks(
-            channels, activation, pooling_kernel
-        )
-        self.decoder_blocks = self._configure_up_sampling_blocks(
-            channels, activation, scale_factor
-        )
-
-        self.head = nn.Conv2d(base_channels, num_classes, kernel_size=1)
-
-    def _configure_down_sampling_blocks(
-        self, channels: List[int], activation: str, pooling_kernel: int
-    ) -> nn.ModuleList:
-        blocks = nn.ModuleList([])
-        for i in range(len(channels) - 1):
-            pooling_kernel = pooling_kernel if i < self.depth - 1 else False
-            dropout_rate = self.dropout_rate if i < 0 else 0
-            blocks += [
-                _DownSamplingBlock(
-                    [channels[i], channels[i + 1], channels[i + 1]],
-                    activation,
-                    self.num_groups,
-                    pooling_kernel,
-                    dropout_rate,
-                    self.kernel_size[i],
-                    self.dilation[i],
-                    self.padding[i],
-                )
-            ]
-
-        return blocks
-
-    def _configure_up_sampling_blocks(
-        self, channels: List[int], activation: str, scale_factor: int,
-    ) -> nn.ModuleList:
-        channels.reverse()
-        self.kernel_size.reverse()
-        self.dilation.reverse()
-        self.padding.reverse()
-        return nn.ModuleList(
-            [
-                _UpSamplingBlock(
-                    [channels[i] + channels[i + 1], channels[i + 1], channels[i + 1]],
-                    activation,
-                    self.num_groups,
-                    scale_factor,
-                    self.dropout_rate,
-                    self.kernel_size[i],
-                    self.dilation[i],
-                    self.padding[i],
-                )
-                for i in range(len(channels) - 2)
-            ]
-        )
-
-    def _encode(self, x: Tensor) -> List[Tensor]:
-        x_skips = []
-        for block in self.encoder_blocks:
-            x, x_skip = block(x)
-            x_skips.append(x_skip)
-        return x_skips
-
-    def _decode(self, x_skips: List[Tensor]) -> Tensor:
-        x = x_skips[-1]
-        for i, block in enumerate(self.decoder_blocks):
-            x = block(x, x_skips[-(i + 2)])
-        return x
-
-    def forward(self, x: Tensor) -> Tensor:
-        """Forward pass with the UNet model."""
-        if len(x.shape) < 4:
-            x = x[(None,) * (4 - len(x.shape))]
-        x = self.conv(x)
-        x_skips = self._encode(x)
-        x = self._decode(x_skips)
-        return self.head(x)
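
Decoder channel bookkeeping for the defaults above (base_channels=64, depth=4): each up block upsamples, concatenates the matching skip connection, and reduces back, mirroring _configure_up_sampling_blocks:

    base, depth = 64, 4
    channels = [base] + [base * 2 ** i for i in range(depth)]  # [64, 64, 128, 256, 512]
    rev = list(reversed(channels))                             # decoder order
    for i in range(len(rev) - 2):
        print(f"up block {i}: {rev[i]} + skip {rev[i + 1]} -> {rev[i + 1]} channels")
    # up block 0: 512 + 256 -> 256, block 1: 256 + 128 -> 128, block 2: 128 + 64 -> 64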
diff --git a/text_recognizer/networks/vit.py b/text_recognizer/networks/vit.py
deleted file mode 100644
index efb3701..0000000
--- a/text_recognizer/networks/vit.py
+++ /dev/null
@@ -1,150 +0,0 @@
-"""A Vision Transformer.
-
-Inspired by:
-https://openreview.net/pdf?id=YicbFdNTTy
-
-"""
-from typing import Optional, Tuple
-
-from einops import rearrange, repeat
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.transformer import Transformer
-
-
-class ViT(nn.Module):
-    """Transformer for image to sequence prediction."""
-
-    def __init__(
-        self,
-        num_encoder_layers: int,
-        num_decoder_layers: int,
-        hidden_dim: int,
-        vocab_size: int,
-        num_heads: int,
-        expansion_dim: int,
-        patch_dim: Tuple[int, int],
-        image_size: Tuple[int, int],
-        dropout_rate: float,
-        trg_pad_index: int,
-        max_len: int,
-        activation: str = "gelu",
-    ) -> None:
-        super().__init__()
-
-        self.trg_pad_index = trg_pad_index
-        self.patch_dim = patch_dim
-        self.num_patches = image_size[-1] // self.patch_dim[1]
-
-        # Encoder
-        self.patch_to_embedding = nn.Linear(
-            self.patch_dim[0] * self.patch_dim[1], hidden_dim
-        )
-        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
-        self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
-        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
-        self.dropout = nn.Dropout(dropout_rate)
-        self._init()
-
-        self.transformer = Transformer(
-            num_encoder_layers,
-            num_decoder_layers,
-            hidden_dim,
-            num_heads,
-            expansion_dim,
-            dropout_rate,
-            activation,
-        )
-
-        self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
-
-    def _init(self) -> None:
-        nn.init.normal_(self.character_embedding.weight, std=0.02)
-        # nn.init.normal_(self.pos_embedding.weight, std=0.02)
-
-    def _create_trg_mask(self, trg: Tensor) -> Tensor:
-        # Move this outside the transformer.
-        trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
-        trg_len = trg.shape[1]
-        trg_sub_mask = torch.tril(
-            torch.ones((trg_len, trg_len), device=trg.device)
-        ).bool()
-        trg_mask = trg_pad_mask & trg_sub_mask
-        return trg_mask
-
-    def encoder(self, src: Tensor) -> Tensor:
-        """Forward pass with the encoder of the transformer."""
-        return self.transformer.encoder(src)
-
-    def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
-        """Forward pass with the decoder of the transformer + classification head."""
-        return self.head(
-            self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
-        )
-
-    def extract_image_features(self, src: Tensor) -> Tensor:
-        """Extracts image features with a backbone neural network.
-
-        It seems like the winning idea was to swap channels and width dimension and collapse
-        the height dimension. The transformer is learning like a baby with this implementation!!! :D
-        Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D
-
-        Args:
-            src (Tensor): Input tensor.
-
-        Returns:
-            Tensor: An input src to the transformer.
-
-        """
-        # If the batch dimension is missing, it needs to be added.
-        if len(src.shape) < 4:
-            src = src[(None,) * (4 - len(src.shape))]
-
-        patches = rearrange(
-            src,
-            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
-            p1=self.patch_dim[0],
-            p2=self.patch_dim[1],
-        )
-
-        # From patches to encoded sequence.
-        x = self.patch_to_embedding(patches)
-        b, n, _ = x.shape
-        cls_tokens = repeat(self.cls_token, "() n d -> b n d", b=b)
-        x = torch.cat((cls_tokens, x), dim=1)
-        x += self.pos_embedding[:, : (n + 1)]
-        x = self.dropout(x)
-
-        return x
-
-    def target_embedding(self, trg: Tensor) -> Tensor:
-        """Encodes target tensor with embedding and position.
-
-        Args:
-            trg (Tensor): Target tensor.
-
-        Returns:
-            Tensor: Encoded target tensor.
-
-        """
-        _, n = trg.shape
-        trg = self.character_embedding(trg.long())
-        trg += self.pos_embedding[:, :n]
-        return trg
-
-    def decode_image_features(self, h: Tensor, trg: Optional[Tensor] = None) -> Tensor:
-        """Takes image features from the backbone and decodes them with the transformer."""
-        trg_mask = self._create_trg_mask(trg)
-        trg = self.target_embedding(trg)
-        out = self.transformer(h, trg, trg_mask=trg_mask)
-
-        logits = self.head(out)
-        return logits
-
-    def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
-        """Forward pass with the CNN transformer."""
-        h = self.extract_image_features(x)
-        logits = self.decode_image_features(h, trg)
-        return logits
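
Patch arithmetic for the ViT above, standalone; the (28, 952) image size and (28, 14) patch_dim are illustrative. A line image yields 952 / 14 = 68 patches of 28 * 14 = 392 pixels each, which patch_to_embedding then projects to hidden_dim:

    import torch
    from einops import rearrange

    src = torch.randn(2, 1, 28, 952)
    patches = rearrange(src, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=28, p2=14)
    print(patches.shape)  # torch.Size([2, 68, 392])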