author     Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-03-20 18:09:06 +0100
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-03-20 18:09:06 +0100
commit     7e8e54e84c63171e748bbf09516fd517e6821ace (patch)
tree       996093f75a5d488dddf7ea1f159ed343a561ef89 /src/text_recognizer/networks
parent     b0719d84138b6bbe5f04a4982dfca673aea1a368 (diff)
Initial commit for refactoring to lightning
Diffstat (limited to 'src/text_recognizer/networks')
-rw-r--r--  src/text_recognizer/networks/__init__.py  43
-rw-r--r--  src/text_recognizer/networks/beam.py  83
-rw-r--r--  src/text_recognizer/networks/cnn.py  101
-rw-r--r--  src/text_recognizer/networks/cnn_transformer.py  158
-rw-r--r--  src/text_recognizer/networks/crnn.py  110
-rw-r--r--  src/text_recognizer/networks/ctc.py  58
-rw-r--r--  src/text_recognizer/networks/densenet.py  225
-rw-r--r--  src/text_recognizer/networks/lenet.py  68
-rw-r--r--  src/text_recognizer/networks/loss/__init__.py  2
-rw-r--r--  src/text_recognizer/networks/loss/loss.py  69
-rw-r--r--  src/text_recognizer/networks/metrics.py  123
-rw-r--r--  src/text_recognizer/networks/mlp.py  73
-rw-r--r--  src/text_recognizer/networks/residual_network.py  310
-rw-r--r--  src/text_recognizer/networks/stn.py  44
-rw-r--r--  src/text_recognizer/networks/transducer/__init__.py  3
-rw-r--r--  src/text_recognizer/networks/transducer/tds_conv.py  208
-rw-r--r--  src/text_recognizer/networks/transducer/test.py  60
-rw-r--r--  src/text_recognizer/networks/transducer/transducer.py  410
-rw-r--r--  src/text_recognizer/networks/transformer/__init__.py  3
-rw-r--r--  src/text_recognizer/networks/transformer/attention.py  93
-rw-r--r--  src/text_recognizer/networks/transformer/positional_encoding.py  32
-rw-r--r--  src/text_recognizer/networks/transformer/transformer.py  264
-rw-r--r--  src/text_recognizer/networks/unet.py  255
-rw-r--r--  src/text_recognizer/networks/util.py  89
-rw-r--r--  src/text_recognizer/networks/vit.py  150
-rw-r--r--  src/text_recognizer/networks/vq_transformer.py  150
-rw-r--r--  src/text_recognizer/networks/vqvae/__init__.py  5
-rw-r--r--  src/text_recognizer/networks/vqvae/decoder.py  133
-rw-r--r--  src/text_recognizer/networks/vqvae/encoder.py  147
-rw-r--r--  src/text_recognizer/networks/vqvae/vector_quantizer.py  119
-rw-r--r--  src/text_recognizer/networks/vqvae/vqvae.py  74
-rw-r--r--  src/text_recognizer/networks/wide_resnet.py  221
32 files changed, 0 insertions, 3883 deletions
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
deleted file mode 100644
index 1521355..0000000
--- a/src/text_recognizer/networks/__init__.py
+++ /dev/null
@@ -1,43 +0,0 @@
-"""Network modules."""
-from .cnn import CNN
-from .cnn_transformer import CNNTransformer
-from .crnn import ConvolutionalRecurrentNetwork
-from .ctc import greedy_decoder
-from .densenet import DenseNet
-from .lenet import LeNet
-from .metrics import accuracy, cer, wer
-from .mlp import MLP
-from .residual_network import ResidualNetwork, ResidualNetworkEncoder
-from .transducer import load_transducer_loss, TDS2d
-from .transformer import Transformer
-from .unet import UNet
-from .util import sliding_window
-from .vit import ViT
-from .vq_transformer import VQTransformer
-from .vqvae import VQVAE
-from .wide_resnet import WideResidualNetwork
-
-__all__ = [
- "accuracy",
- "cer",
- "CNN",
- "CNNTransformer",
- "ConvolutionalRecurrentNetwork",
- "DenseNet",
- "FCN",
- "greedy_decoder",
- "MLP",
- "LeNet",
- "load_transducer_loss",
- "ResidualNetwork",
- "ResidualNetworkEncoder",
- "sliding_window",
- "UNet",
- "TDS2d",
- "Transformer",
- "ViT",
- "VQTransformer",
- "VQVAE",
- "wer",
- "WideResidualNetwork",
-]
diff --git a/src/text_recognizer/networks/beam.py b/src/text_recognizer/networks/beam.py
deleted file mode 100644
index dccccdb..0000000
--- a/src/text_recognizer/networks/beam.py
+++ /dev/null
@@ -1,83 +0,0 @@
-"""Implementation of beam search decoder for a sequence to sequence network.
-
-Stolen from: https://github.com/budzianowski/PyTorch-Beam-Search-Decoding/blob/master/decode_beam.py
-
-"""
-# from typing import List
-# from queue import PriorityQueue
-
-# from loguru import logger
-# import torch
-# from torch import nn
-# from torch import Tensor
-# import torch.nn.functional as F
-
-
-# class Node:
-# def __init__(
-# self, parent: Node, target_index: int, log_prob: Tensor, length: int
-# ) -> None:
-# self.parent = parent
-# self.target_index = target_index
-# self.log_prob = log_prob
-# self.length = length
-# self.reward = 0.0
-
-# def eval(self, alpha: float = 1.0) -> Tensor:
-# return self.log_prob / (self.length - 1 + 1e-6) + alpha * self.reward
-
-
-# @torch.no_grad()
-# def beam_decoder(
-# network, mapper, device, memory: Tensor = None, max_len: int = 97,
-# ) -> Tensor:
-# beam_width = 10
-# topk = 1 # How many sentences to generate.
-
-# trg_indices = [mapper(mapper.init_token)]
-
-# end_nodes = []
-
-# node = Node(None, trg_indices, 0, 1)
-# nodes = PriorityQueue()
-
-# nodes.put((node.eval(), node))
-# q_size = 1
-
-# # Beam search
-# for _ in range(max_len):
-# if q_size > 2000:
-# logger.warning("Could not decode input")
-# break
-
-# # Fetch the best node.
-# score, n = nodes.get()
-# decoder_input = n.target_index
-
-# if n.target_index == mapper(mapper.eos_token) and n.parent is not None:
-# end_nodes.append((score, n))
-
-# # If we reached the maximum number of sentences required.
-# if len(end_nodes) >= 1:
-# break
-# else:
-# continue
-
-# # Forward pass with transformer.
-# trg = torch.tensor(trg_indices, device=device)[None, :].long()
-# trg = network.target_embedding(trg)
-# logits = network.decoder(trg=trg, memory=memory, trg_mask=None)
-# log_prob = F.log_softmax(logits, dim=2)
-
-# log_prob, indices = torch.topk(log_prob, beam_width)
-
-# for new_k in range(beam_width):
-# # TODO: continue from here
-# token_index = indices[0][new_k].view(1, -1)
-# log_p = log_prob[0][new_k].item()
-
-# node = Node()
-
-# pass
-
-# pass
diff --git a/src/text_recognizer/networks/cnn.py b/src/text_recognizer/networks/cnn.py
deleted file mode 100644
index 1807bb9..0000000
--- a/src/text_recognizer/networks/cnn.py
+++ /dev/null
@@ -1,101 +0,0 @@
-"""Implementation of a simple backbone cnn network."""
-from typing import Optional, Tuple
-
-import torch
-from torch import nn
-
-from text_recognizer.networks.util import activation_function
-
-
-class CNN(nn.Module):
- """LeNet network for character prediction."""
-
- def __init__(
- self,
- channels: Tuple[int, ...] = (1, 32, 64, 128),
- kernel_sizes: Tuple[int, ...] = (4, 4, 4),
- strides: Tuple[int, ...] = (2, 2, 2),
- max_pool_kernel: int = 2,
- dropout_rate: float = 0.2,
- activation: Optional[str] = "relu",
- ) -> None:
- """Initialization of the LeNet network.
-
- Args:
- channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64).
- kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2).
- strides (Tuple[int, ...]): Stride length of the convolutional filter. Defaults to (2, 2, 2).
- max_pool_kernel (int): 2D max pooling kernel. Defaults to 2.
- dropout_rate (float): The dropout rate. Defaults to 0.2.
- activation (Optional[str]): The name of non-linear activation function. Defaults to relu.
-
- Raises:
- RuntimeError: if the number of hyperparameters does not match in length.
-
- """
- super().__init__()
-
- if len(channels) - 1 != len(kernel_sizes) or len(kernel_sizes) != len(strides):
- raise RuntimeError("The number of hyperparameters does not match in length.")
-
- self.cnn = self._build_network(
- channels, kernel_sizes, strides, max_pool_kernel, dropout_rate, activation,
- )
-
- def _build_network(
- self,
- channels: Tuple[int, ...],
- kernel_sizes: Tuple[int, ...],
- strides: Tuple[int, ...],
- max_pool_kernel: int,
- dropout_rate: float,
- activation: str,
- ) -> nn.Sequential:
- # Load activation function.
- activation_fn = activation_function(activation)
-
- channels = list(channels)
- in_channels = channels.pop(0)
- configuration = zip(channels, kernel_sizes, strides)
-
- modules = nn.ModuleList([])
-
- for i, (out_channels, kernel_size, stride) in enumerate(configuration):
- # Add max pool to reduce output size.
- if i == len(channels) // 2:
- modules.append(nn.MaxPool2d(max_pool_kernel))
- if i == 0:
- modules.append(
- nn.Conv2d(
- in_channels, out_channels, kernel_size, stride=stride, padding=1
- )
- )
- else:
- modules.append(
- nn.Sequential(
- activation_fn,
- nn.BatchNorm2d(in_channels),
- nn.Conv2d(
- in_channels,
- out_channels,
- kernel_size,
- stride=stride,
- padding=1,
- ),
- )
- )
-
- if dropout_rate:
- modules.append(nn.Dropout2d(p=dropout_rate))
-
- in_channels = out_channels
-
- return nn.Sequential(*modules)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """The feedforward pass."""
- # If the batch dimension is missing, it needs to be added.
- if len(x.shape) < 4:
- x = x[(None,) * (4 - len(x.shape))]
- return self.cnn(x)
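
A minimal standalone sketch of the `x[(None,) * (4 - len(x.shape))]` idiom used in `forward` above: indexing with a tuple of `None`s prepends the missing singleton dimensions.

import torch

x = torch.randn(28, 28)              # (H, W): no channel or batch dimension
x = x[(None,) * (4 - len(x.shape))]  # prepend two singleton dims -> (1, 1, 28, 28)
assert x.shape == (1, 1, 28, 28)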
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
deleted file mode 100644
index a2d7926..0000000
--- a/src/text_recognizer/networks/cnn_transformer.py
+++ /dev/null
@@ -1,158 +0,0 @@
-"""A CNN-Transformer for image to text recognition."""
-from typing import Dict, Optional, Tuple
-
-from einops import rearrange
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.transformer import PositionalEncoding, Transformer
-from text_recognizer.networks.util import activation_function
-from text_recognizer.networks.util import configure_backbone
-
-
-class CNNTransformer(nn.Module):
- """CNN+Transfomer for image to sequence prediction."""
-
- def __init__(
- self,
- num_encoder_layers: int,
- num_decoder_layers: int,
- hidden_dim: int,
- vocab_size: int,
- num_heads: int,
- adaptive_pool_dim: Tuple,
- expansion_dim: int,
- dropout_rate: float,
- trg_pad_index: int,
- max_len: int,
- backbone: str,
- backbone_args: Optional[Dict] = None,
- activation: str = "gelu",
- pool_kernel: Optional[Tuple[int, int]] = None,
- ) -> None:
- super().__init__()
- self.trg_pad_index = trg_pad_index
- self.vocab_size = vocab_size
- self.backbone = configure_backbone(backbone, backbone_args)
-
- if pool_kernel is not None:
- self.max_pool = nn.MaxPool2d(pool_kernel, stride=2)
- else:
- self.max_pool = None
-
- self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim)
-
- self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
- self.pos_dropout = nn.Dropout(p=dropout_rate)
- self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
-
- nn.init.normal_(self.character_embedding.weight, std=0.02)
-
- self.adaptive_pool = (
- nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None
- )
-
- self.transformer = Transformer(
- num_encoder_layers,
- num_decoder_layers,
- hidden_dim,
- num_heads,
- expansion_dim,
- dropout_rate,
- activation,
- )
-
- self.head = nn.Sequential(
- # nn.Linear(hidden_dim, hidden_dim * 2),
- # activation_function(activation),
- nn.Linear(hidden_dim, vocab_size),
- )
-
- def _create_trg_mask(self, trg: Tensor) -> Tensor:
- # Move this outside the transformer.
- trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
- trg_len = trg.shape[1]
- trg_sub_mask = torch.tril(
- torch.ones((trg_len, trg_len), device=trg.device)
- ).bool()
- trg_mask = trg_pad_mask & trg_sub_mask
- return trg_mask
-
- def encoder(self, src: Tensor) -> Tensor:
- """Forward pass with the encoder of the transformer."""
- return self.transformer.encoder(src)
-
- def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
- """Forward pass with the decoder of the transformer + classification head."""
- return self.head(
- self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
- )
-
- def extract_image_features(self, src: Tensor) -> Tensor:
- """Extracts image features with a backbone neural network.
-
- The key idea is to swap the channel and width dimensions and collapse
- the height dimension, so the transformer receives a sequence of column
- features. This made the transformer train dramatically better.
-
- Args:
- src (Tensor): Input tensor.
-
- Returns:
- Tensor: The src input to the transformer.
-
- """
- # If batch dimension is missing, it needs to be added.
- if len(src.shape) < 4:
- src = src[(None,) * (4 - len(src.shape))]
-
- src = self.backbone(src)
-
- if self.max_pool is not None:
- src = self.max_pool(src)
-
- if self.adaptive_pool is not None and len(src.shape) == 4:
- src = rearrange(src, "b c h w -> b w c h")
- src = self.adaptive_pool(src)
- src = src.squeeze(3)
- elif len(src.shape) == 4:
- src = rearrange(src, "b c h w -> b (h w) c")
-
- b, t, _ = src.shape
-
- src += self.src_position_embedding[:, :t]
- src = self.pos_dropout(src)
-
- return src
-
- def target_embedding(self, trg: Tensor) -> Tensor:
- """Encodes the target tensor with character embedding and positional encoding.
-
- Args:
- trg (Tensor): Target tensor.
-
- Returns:
- Tensor: Encoded target tensor.
-
- """
- trg = self.character_embedding(trg.long())
- trg = self.trg_position_encoding(trg)
- return trg
-
- def decode_image_features(
- self, image_features: Tensor, trg: Optional[Tensor] = None
- ) -> Tensor:
- """Takes images features from the backbone and decodes them with the transformer."""
- trg_mask = self._create_trg_mask(trg)
- trg = self.target_embedding(trg)
- out = self.transformer(image_features, trg, trg_mask=trg_mask)
-
- logits = self.head(out)
- return logits
-
- def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
- """Forward pass with CNN transfomer."""
- image_features = self.extract_image_features(x)
- logits = self.decode_image_features(image_features, trg)
- return logits
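
A minimal sketch of the target-mask construction in `_create_trg_mask` above, assuming a pad index of 0 and a batch of two length-3 target sequences:

import torch

trg = torch.tensor([[5, 7, 0], [3, 4, 9]])          # (batch, trg_len), 0 = pad
trg_pad_mask = (trg != 0)[:, None, None]            # (batch, 1, 1, trg_len)
trg_sub_mask = torch.tril(torch.ones(3, 3)).bool()  # causal (trg_len, trg_len) mask
trg_mask = trg_pad_mask & trg_sub_mask              # broadcasts to (batch, 1, trg_len, trg_len)
assert trg_mask.shape == (2, 1, 3, 3)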
diff --git a/src/text_recognizer/networks/crnn.py b/src/text_recognizer/networks/crnn.py
deleted file mode 100644
index 778e232..0000000
--- a/src/text_recognizer/networks/crnn.py
+++ /dev/null
@@ -1,110 +0,0 @@
-"""CRNN for handwritten text recognition."""
-from typing import Dict, Optional, Tuple
-
-from einops import rearrange, reduce
-from einops.layers.torch import Rearrange
-from loguru import logger
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import configure_backbone
-
-
-class ConvolutionalRecurrentNetwork(nn.Module):
- """Network that takes a image of a text line and predicts tokens that are in the image."""
-
- def __init__(
- self,
- backbone: str,
- backbone_args: Optional[Dict] = None,
- input_size: int = 128,
- hidden_size: int = 128,
- bidirectional: bool = False,
- num_layers: int = 1,
- num_classes: int = 80,
- patch_size: Tuple[int, int] = (28, 28),
- stride: Tuple[int, int] = (1, 14),
- recurrent_cell: str = "lstm",
- avg_pool: bool = False,
- use_sliding_window: bool = True,
- ) -> None:
- super().__init__()
- self.backbone_args = backbone_args or {}
- self.patch_size = patch_size
- self.stride = stride
- self.sliding_window = (
- self._configure_sliding_window() if use_sliding_window else None
- )
- self.input_size = input_size
- self.hidden_size = hidden_size
- self.backbone = configure_backbone(backbone, backbone_args)
- self.bidirectional = bidirectional
- self.avg_pool = avg_pool
-
- if recurrent_cell.upper() in ["LSTM", "GRU"]:
- recurrent_cell = getattr(nn, recurrent_cell.upper())
- else:
- logger.warning(
- f"Option {recurrent_cell} not valid, defaulting to LSTM cell."
- )
- recurrent_cell = nn.LSTM
-
- self.rnn = recurrent_cell(
- input_size=self.input_size,
- hidden_size=self.hidden_size,
- bidirectional=bidirectional,
- num_layers=num_layers,
- )
-
- decoder_size = self.hidden_size * 2 if self.bidirectional else self.hidden_size
-
- self.decoder = nn.Sequential(
- nn.Linear(in_features=decoder_size, out_features=num_classes),
- nn.LogSoftmax(dim=2),
- )
-
- def _configure_sliding_window(self) -> nn.Sequential:
- return nn.Sequential(
- nn.Unfold(kernel_size=self.patch_size, stride=self.stride),
- Rearrange(
- "b (c h w) t -> b t c h w",
- h=self.patch_size[0],
- w=self.patch_size[1],
- c=1,
- ),
- )
-
- def forward(self, x: Tensor) -> Tensor:
- """Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM."""
- if len(x.shape) < 4:
- x = x[(None,) * (4 - len(x.shape))]
-
- if self.sliding_window is not None:
- # Create image patches with a sliding window kernel.
- x = self.sliding_window(x)
-
- # Rearrange from a sequence of patches for feedforward network.
- b, t = x.shape[:2]
- x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
-
- x = self.backbone(x)
-
- # Average pooling.
- if self.avg_pool:
- x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)
- else:
- x = rearrange(x, "(b t) h -> t b h", b=b, t=t)
- else:
- # Encode the entire image with a CNN, and use the channels as temporal dimension.
- x = self.backbone(x)
- x = rearrange(x, "b c h w -> b w c h")
- # Note: no `adaptive_pool` attribute is defined in `__init__`, so guard
- # against it being absent instead of raising an AttributeError here.
- if getattr(self, "adaptive_pool", None) is not None:
- x = self.adaptive_pool(x)
- x = x.squeeze(3)
-
- # Sequence predictions.
- x, _ = self.rnn(x)
-
- # Sequence to classification layer.
- x = self.decoder(x)
- return x
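
A minimal sketch of the sliding-window patching configured in `_configure_sliding_window` above, assuming a 1x28x56 line image, 28x28 patches, and a stride of (1, 14):

import torch
from torch import nn
from einops.layers.torch import Rearrange

window = nn.Sequential(
    nn.Unfold(kernel_size=(28, 28), stride=(1, 14)),
    Rearrange("b (c h w) t -> b t c h w", c=1, h=28, w=28),
)
x = torch.randn(1, 1, 28, 56)
patches = window(x)
assert patches.shape == (1, 3, 1, 28, 28)  # three overlapping 28x28 patches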
diff --git a/src/text_recognizer/networks/ctc.py b/src/text_recognizer/networks/ctc.py
deleted file mode 100644
index af9b700..0000000
--- a/src/text_recognizer/networks/ctc.py
+++ /dev/null
@@ -1,58 +0,0 @@
-"""Decodes the CTC output."""
-from typing import Callable, List, Optional, Tuple
-
-from einops import rearrange
-import torch
-from torch import Tensor
-
-from text_recognizer.datasets.util import EmnistMapper
-
-
-def greedy_decoder(
- predictions: Tensor,
- targets: Optional[Tensor] = None,
- target_lengths: Optional[Tensor] = None,
- character_mapper: Optional[Callable] = None,
- blank_label: int = 79,
- collapse_repeated: bool = True,
-) -> Tuple[List[str], List[str]]:
- """Greedy CTC decoder.
-
- Args:
- predictions (Tensor): Tensor of network predictions, shape [time, batch, classes].
- targets (Optional[Tensor]): Target tensor, shape is [batch, targets]. Defaults to None.
- target_lengths (Optional[Tensor]): Length of each target tensor. Defaults to None.
- character_mapper (Optional[Callable]): An emnist/character mapper for mapping integers to characters. Defaults
- to None.
- blank_label (int): The blank character to be ignored. Defaults to 79.
- collapse_repeated (bool): Collapse consecutive predictions of the same character. Defaults to True.
-
- Returns:
- Tuple[List[str], List[str]]: Tuple of decoded predictions and decoded targets.
-
- """
-
- if character_mapper is None:
- character_mapper = EmnistMapper(pad_token="_") # noqa: S106
-
- predictions = rearrange(torch.argmax(predictions, dim=2), "t b -> b t")
- decoded_predictions = []
- decoded_targets = []
- for i, prediction in enumerate(predictions):
- decoded_prediction = []
- decoded_target = []
- if targets is not None and target_lengths is not None:
- for target_index in targets[i][: target_lengths[i]]:
- if target_index == blank_label:
- continue
- decoded_target.append(character_mapper(int(target_index)))
- decoded_targets.append(decoded_target)
- for j, index in enumerate(prediction):
- if index != blank_label:
- if collapse_repeated and j != 0 and index == prediction[j - 1]:
- continue
- decoded_prediction.append(index.item())
- decoded_predictions.append(
- [character_mapper(int(pred_index)) for pred_index in decoded_prediction]
- )
- return decoded_predictions, decoded_targets
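
A minimal sketch of the greedy collapse rule above, shown without the project-specific EmnistMapper: take the argmax path, drop blanks, and merge repeats unless a blank separates them.

import torch

blank = 3
prediction = torch.tensor([1, 1, 3, 2, 2, 3, 2])  # argmax indices over time
decoded = []
for j, index in enumerate(prediction):
    if index != blank:
        if j != 0 and index == prediction[j - 1]:
            continue  # collapse consecutive repeats
        decoded.append(index.item())
assert decoded == [1, 2, 2]  # the blank keeps the final "2" distinct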
diff --git a/src/text_recognizer/networks/densenet.py b/src/text_recognizer/networks/densenet.py
deleted file mode 100644
index 7dc58d9..0000000
--- a/src/text_recognizer/networks/densenet.py
+++ /dev/null
@@ -1,225 +0,0 @@
-"""Defines a Densely Connected Convolutional Networks in PyTorch.
-
-Sources:
-https://arxiv.org/abs/1608.06993
-https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py
-
-"""
-from typing import List, Union
-
-from einops.layers.torch import Rearrange
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import activation_function
-
-
-class _DenseLayer(nn.Module):
- """A dense layer with pre-batch norm -> activation function -> Conv-layer x 2."""
-
- def __init__(
- self,
- in_channels: int,
- growth_rate: int,
- bn_size: int,
- dropout_rate: float,
- activation: str = "relu",
- ) -> None:
- super().__init__()
- activation_fn = activation_function(activation)
- self.dense_layer = [
- nn.BatchNorm2d(in_channels),
- activation_fn,
- nn.Conv2d(
- in_channels=in_channels,
- out_channels=bn_size * growth_rate,
- kernel_size=1,
- stride=1,
- bias=False,
- ),
- nn.BatchNorm2d(bn_size * growth_rate),
- activation_fn,
- nn.Conv2d(
- in_channels=bn_size * growth_rate,
- out_channels=growth_rate,
- kernel_size=3,
- stride=1,
- padding=1,
- bias=False,
- ),
- ]
- if dropout_rate:
- self.dense_layer.append(nn.Dropout(p=dropout_rate))
-
- self.dense_layer = nn.Sequential(*self.dense_layer)
-
- def forward(self, x: Union[Tensor, List[Tensor]]) -> Tensor:
- if isinstance(x, list):
- x = torch.cat(x, 1)
- return self.dense_layer(x)
-
-
-class _DenseBlock(nn.Module):
- def __init__(
- self,
- num_layers: int,
- in_channels: int,
- bn_size: int,
- growth_rate: int,
- dropout_rate: float,
- activation: str = "relu",
- ) -> None:
- super().__init__()
- self.dense_block = self._build_dense_blocks(
- num_layers, in_channels, bn_size, growth_rate, dropout_rate, activation,
- )
-
- def _build_dense_blocks(
- self,
- num_layers: int,
- in_channels: int,
- bn_size: int,
- growth_rate: int,
- dropout_rate: float,
- activation: str = "relu",
- ) -> nn.ModuleList:
- dense_block = []
- for i in range(num_layers):
- dense_block.append(
- _DenseLayer(
- in_channels=in_channels + i * growth_rate,
- growth_rate=growth_rate,
- bn_size=bn_size,
- dropout_rate=dropout_rate,
- activation=activation,
- )
- )
- return nn.ModuleList(dense_block)
-
- def forward(self, x: Tensor) -> Tensor:
- feature_maps = [x]
- for layer in self.dense_block:
- x = layer(feature_maps)
- feature_maps.append(x)
- return torch.cat(feature_maps, 1)
-
-
-class _Transition(nn.Module):
- def __init__(
- self, in_channels: int, out_channels: int, activation: str = "relu",
- ) -> None:
- super().__init__()
- activation_fn = activation_function(activation)
- self.transition = nn.Sequential(
- nn.BatchNorm2d(in_channels),
- activation_fn,
- nn.Conv2d(
- in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=1,
- stride=1,
- bias=False,
- ),
- nn.AvgPool2d(kernel_size=2, stride=2),
- )
-
- def forward(self, x: Tensor) -> Tensor:
- return self.transition(x)
-
-
-class DenseNet(nn.Module):
- """Implementation of Densenet, a network archtecture that concats previous layers for maximum infomation flow."""
-
- def __init__(
- self,
- growth_rate: int = 32,
- block_config: List[int] = (6, 12, 24, 16),
- in_channels: int = 1,
- base_channels: int = 64,
- num_classes: int = 80,
- bn_size: int = 4,
- dropout_rate: float = 0,
- classifier: bool = True,
- activation: str = "relu",
- ) -> None:
- super().__init__()
- self.densenet = self._configure_densenet(
- in_channels,
- base_channels,
- num_classes,
- growth_rate,
- block_config,
- bn_size,
- dropout_rate,
- classifier,
- activation,
- )
-
- def _configure_densenet(
- self,
- in_channels: int,
- base_channels: int,
- num_classes: int,
- growth_rate: int,
- block_config: List[int],
- bn_size: int,
- dropout_rate: float,
- classifier: bool,
- activation: str,
- ) -> nn.Sequential:
- activation_fn = activation_function(activation)
- densenet = [
- nn.Conv2d(
- in_channels=in_channels,
- out_channels=base_channels,
- kernel_size=3,
- stride=1,
- padding=1,
- bias=False,
- ),
- nn.BatchNorm2d(base_channels),
- activation_fn,
- ]
-
- num_features = base_channels
-
- for i, num_layers in enumerate(block_config):
- densenet.append(
- _DenseBlock(
- num_layers=num_layers,
- in_channels=num_features,
- bn_size=bn_size,
- growth_rate=growth_rate,
- dropout_rate=dropout_rate,
- activation=activation,
- )
- )
- num_features = num_features + num_layers * growth_rate
- if i != len(block_config) - 1:
- densenet.append(
- _Transition(
- in_channels=num_features,
- out_channels=num_features // 2,
- activation=activation,
- )
- )
- num_features = num_features // 2
-
- densenet.append(activation_fn)
-
- if classifier:
- densenet.append(nn.AdaptiveAvgPool2d((1, 1)))
- densenet.append(Rearrange("b c h w -> b (c h w)"))
- densenet.append(
- nn.Linear(in_features=num_features, out_features=num_classes)
- )
-
- return nn.Sequential(*densenet)
-
- def forward(self, x: Tensor) -> Tensor:
- """Forward pass of Densenet."""
- # If the batch dimension is missing, it will be added.
- if len(x.shape) < 4:
- x = x[(None,) * (4 - len(x.shape))]
- return self.densenet(x)
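
A minimal sketch of the channel bookkeeping in `_configure_densenet` above, assuming the defaults growth_rate=32, base_channels=64 and block_config=(6, 12, 24, 16):

num_features = 64
for i, num_layers in enumerate((6, 12, 24, 16)):
    num_features += num_layers * 32  # every dense layer adds growth_rate channels
    if i != 3:
        num_features //= 2           # each transition layer halves the channels
assert num_features == 1024          # in_features of the final classifier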
diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py
deleted file mode 100644
index 527e1a0..0000000
--- a/src/text_recognizer/networks/lenet.py
+++ /dev/null
@@ -1,68 +0,0 @@
-"""Implementation of the LeNet network."""
-from typing import Optional, Tuple
-
-from einops.layers.torch import Rearrange
-import torch
-from torch import nn
-
-from text_recognizer.networks.util import activation_function
-
-
-class LeNet(nn.Module):
- """LeNet network for character prediction."""
-
- def __init__(
- self,
- channels: Tuple[int, ...] = (1, 32, 64),
- kernel_sizes: Tuple[int, ...] = (3, 3, 2),
- hidden_size: Tuple[int, ...] = (9216, 128),
- dropout_rate: float = 0.2,
- num_classes: int = 10,
- activation_fn: Optional[str] = "relu",
- ) -> None:
- """Initialization of the LeNet network.
-
- Args:
- channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64).
- kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2).
- hidden_size (Tuple[int, ...]): Size of the flattened output from the convolutional layers.
- Defaults to (9216, 128).
- dropout_rate (float): The dropout rate. Defaults to 0.2.
- num_classes (int): Number of classes. Defaults to 10.
- activation_fn (Optional[str]): The name of non-linear activation function. Defaults to relu.
-
- """
- super().__init__()
-
- activation_fn = activation_function(activation_fn)
-
- self.layers = [
- nn.Conv2d(
- in_channels=channels[0],
- out_channels=channels[1],
- kernel_size=kernel_sizes[0],
- ),
- activation_fn,
- nn.Conv2d(
- in_channels=channels[1],
- out_channels=channels[2],
- kernel_size=kernel_sizes[1],
- ),
- activation_fn,
- nn.MaxPool2d(kernel_sizes[2]),
- nn.Dropout(p=dropout_rate),
- Rearrange("b c h w -> b (c h w)"),
- nn.Linear(in_features=hidden_size[0], out_features=hidden_size[1]),
- activation_fn,
- nn.Dropout(p=dropout_rate),
- nn.Linear(in_features=hidden_size[1], out_features=num_classes),
- ]
-
- self.layers = nn.Sequential(*self.layers)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """The feedforward pass."""
- # If the batch dimension is missing, it needs to be added.
- if len(x.shape) < 4:
- x = x[(None,) * (4 - len(x.shape))]
- return self.layers(x)
diff --git a/src/text_recognizer/networks/loss/__init__.py b/src/text_recognizer/networks/loss/__init__.py
deleted file mode 100644
index b489264..0000000
--- a/src/text_recognizer/networks/loss/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-"""Loss module."""
-from .loss import EmbeddingLoss, LabelSmoothingCrossEntropy
diff --git a/src/text_recognizer/networks/loss/loss.py b/src/text_recognizer/networks/loss/loss.py
deleted file mode 100644
index cf9fa0d..0000000
--- a/src/text_recognizer/networks/loss/loss.py
+++ /dev/null
@@ -1,69 +0,0 @@
-"""Implementations of custom loss functions."""
-from pytorch_metric_learning import distances, losses, miners, reducers
-import torch
-from torch import nn
-from torch import Tensor
-from typing import Optional
-
-__all__ = ["EmbeddingLoss", "LabelSmoothingCrossEntropy"]
-
-
-class EmbeddingLoss:
- """Metric loss for training encoders to produce information-rich latent embeddings."""
-
- def __init__(self, margin: float = 0.2, type_of_triplets: str = "semihard") -> None:
- self.distance = distances.CosineSimilarity()
- self.reducer = reducers.ThresholdReducer(low=0)
- self.loss_fn = losses.TripletMarginLoss(
- margin=margin, distance=self.distance, reducer=self.reducer
- )
- self.miner = miners.MultiSimilarityMiner(epsilon=margin, distance=self.distance)
-
- def __call__(self, embeddings: Tensor, labels: Tensor) -> Tensor:
- """Computes the metric loss for the embeddings based on their labels.
-
- Args:
- embeddings (Tensor): The latent vectors encoded by the network.
- labels (Tensor): Labels of the embeddings.
-
- Returns:
- Tensor: The metric loss for the embeddings.
-
- """
- hard_pairs = self.miner(embeddings, labels)
- loss = self.loss_fn(embeddings, labels, hard_pairs)
- return loss
-
-
-class LabelSmoothingCrossEntropy(nn.Module):
- """Label smoothing loss function."""
-
- def __init__(
- self,
- classes: int,
- smoothing: float = 0.0,
- ignore_index: Optional[int] = None,
- dim: int = -1,
- ) -> None:
- super().__init__()
- self.confidence = 1.0 - smoothing
- self.smoothing = smoothing
- self.ignore_index = ignore_index
- self.cls = classes
- self.dim = dim
-
- def forward(self, pred: Tensor, target: Tensor) -> Tensor:
- """Calculates the loss."""
- pred = pred.log_softmax(dim=self.dim)
- with torch.no_grad():
- # true_dist = pred.data.clone()
- true_dist = torch.zeros_like(pred)
- true_dist.fill_(self.smoothing / (self.cls - 1))
- true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
- if self.ignore_index is not None:
- true_dist[:, self.ignore_index] = 0
- mask = torch.nonzero(target == self.ignore_index, as_tuple=False)
- if mask.dim() > 0:
- true_dist.index_fill_(0, mask.squeeze(), 0.0)
- return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
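
A minimal sketch of the smoothed target distribution built in `forward` above, assuming 4 classes and smoothing=0.1: the true class receives 0.9 and the remaining mass is spread uniformly over the other classes.

import torch

classes, smoothing = 4, 0.1
target = torch.tensor([2])
true_dist = torch.zeros(1, classes)
true_dist.fill_(smoothing / (classes - 1))                   # 0.1 / 3 per wrong class
true_dist.scatter_(1, target.unsqueeze(1), 1.0 - smoothing)  # confidence on the true class
print(true_dist)  # tensor([[0.0333, 0.0333, 0.9000, 0.0333]])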
diff --git a/src/text_recognizer/networks/metrics.py b/src/text_recognizer/networks/metrics.py
deleted file mode 100644
index 2605731..0000000
--- a/src/text_recognizer/networks/metrics.py
+++ /dev/null
@@ -1,123 +0,0 @@
-"""Utility functions for models."""
-from typing import Optional
-
-from einops import rearrange
-import Levenshtein as Lev
-import torch
-from torch import Tensor
-
-from text_recognizer.networks.ctc import greedy_decoder
-
-
-def accuracy(outputs: Tensor, labels: Tensor, pad_index: int = 53) -> float:
- """Computes the accuracy.
-
- Args:
- outputs (Tensor): The output from the network.
- labels (Tensor): Ground truth labels.
- pad_index (int): Padding index.
-
- Returns:
- float: The accuracy for the batch.
-
- """
-
- _, predicted = torch.max(outputs, dim=-1)
-
- # Mask out the pad tokens.
- mask = labels != pad_index
-
- # Count matches on non-padded positions only; zeroing both tensors would
- # count every pad position as a correct prediction.
- acc = ((predicted == labels) & mask).sum().float() / mask.sum().float()
- acc = acc.item()
- return acc
-
-
-def cer(
- outputs: Tensor,
- targets: Tensor,
- batch_size: Optional[int] = None,
- blank_label: int = 79,
-) -> float:
- """Computes the character error rate.
-
- Args:
- outputs (Tensor): The output from the network.
- targets (Tensor): Ground truth labels.
- batch_size (Optional[int]): Batch size if target and output have been flattened.
- blank_label (int): The blank character to be ignored. Defaults to 79.
-
- Returns:
- float: The cer for the batch.
-
- """
- if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None:
- targets = rearrange(targets, "(b t) -> b t", b=batch_size)
- outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size)
-
- target_lengths = torch.full(
- size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
- )
- decoded_predictions, decoded_targets = greedy_decoder(
- outputs, targets, target_lengths, blank_label=blank_label,
- )
-
- lev_dist = 0
-
- for prediction, target in zip(decoded_predictions, decoded_targets):
- prediction = "".join(prediction)
- target = "".join(target)
- prediction, target = (
- prediction.replace(" ", ""),
- target.replace(" ", ""),
- )
- lev_dist += Lev.distance(prediction, target)
- return lev_dist / len(decoded_predictions)
-
-
-def wer(
- outputs: Tensor,
- targets: Tensor,
- batch_size: Optional[int] = None,
- blank_label: int = 79,
-) -> float:
- """Computes the Word error rate.
-
- Args:
- outputs (Tensor): The output from the network.
- targets (Tensor): Ground truth labels.
- batch_size (Optional[int]): Batch size if target and output have been flattened.
- blank_label (int): The blank character to be ignored. Defaults to 79.
-
- Returns:
- float: The wer for the batch.
-
- """
- if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None:
- targets = rearrange(targets, "(b t) -> b t", b=batch_size)
- outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size)
-
- target_lengths = torch.full(
- size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
- )
- decoded_predictions, decoded_targets = greedy_decoder(
- outputs, targets, target_lengths, blank_label=blank_label,
- )
-
- lev_dist = 0
-
- for prediction, target in zip(decoded_predictions, decoded_targets):
- prediction = "".join(prediction)
- target = "".join(target)
-
- b = set(prediction.split() + target.split())
- word2char = dict(zip(b, range(len(b))))
-
- w1 = [chr(word2char[w]) for w in prediction.split()]
- w2 = [chr(word2char[w]) for w in target.split()]
-
- lev_dist += Lev.distance("".join(w1), "".join(w2))
-
- return lev_dist / len(decoded_predictions)
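
A minimal sketch of the word-to-character trick in `wer` above: mapping each unique word to a single character lets Levenshtein distance count whole-word edits.

import Levenshtein as Lev

prediction, target = "the cat sat", "the cat sit"
vocab = set(prediction.split() + target.split())
word2char = dict(zip(vocab, range(len(vocab))))
w1 = "".join(chr(word2char[w]) for w in prediction.split())
w2 = "".join(chr(word2char[w]) for w in target.split())
assert Lev.distance(w1, w2) == 1  # one substituted word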
diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py
deleted file mode 100644
index 1101912..0000000
--- a/src/text_recognizer/networks/mlp.py
+++ /dev/null
@@ -1,73 +0,0 @@
-"""Defines the MLP network."""
-from typing import List, Union
-
-from einops.layers.torch import Rearrange
-import torch
-from torch import nn
-
-from text_recognizer.networks.util import activation_function
-
-
-class MLP(nn.Module):
- """Multi layered perceptron network."""
-
- def __init__(
- self,
- input_size: int = 784,
- num_classes: int = 10,
- hidden_size: Union[int, List] = 128,
- num_layers: int = 3,
- dropout_rate: float = 0.2,
- activation_fn: str = "relu",
- ) -> None:
- """Initialization of the MLP network.
-
- Args:
- input_size (int): The input shape of the network. Defaults to 784.
- num_classes (int): Number of classes in the dataset. Defaults to 10.
- hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128.
- num_layers (int): The number of hidden layers. Defaults to 3.
- dropout_rate (float): The dropout rate at each layer. Defaults to 0.2.
- activation_fn (str): Name of the activation function in the hidden layers. Defaults to
- relu.
-
- """
- super().__init__()
-
- activation_fn = activation_function(activation_fn)
-
- if isinstance(hidden_size, int):
- hidden_size = [hidden_size] * num_layers
-
- self.layers = [
- Rearrange("b c h w -> b (c h w)"),
- nn.Linear(in_features=input_size, out_features=hidden_size[0]),
- activation_fn,
- ]
-
- for i in range(num_layers - 1):
- self.layers += [
- nn.Linear(in_features=hidden_size[i], out_features=hidden_size[i + 1]),
- activation_fn,
- ]
-
- if dropout_rate:
- self.layers.append(nn.Dropout(p=dropout_rate))
-
- self.layers.append(
- nn.Linear(in_features=hidden_size[-1], out_features=num_classes)
- )
-
- self.layers = nn.Sequential(*self.layers)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """The feedforward pass."""
- # If the batch dimension is missing, it needs to be added.
- if len(x.shape) < 4:
- x = x[(None,) * (4 - len(x.shape))]
- return self.layers(x)
-
- @property
- def __name__(self) -> str:
- """Returns the name of the network."""
- return "mlp"
diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py
deleted file mode 100644
index c33f419..0000000
--- a/src/text_recognizer/networks/residual_network.py
+++ /dev/null
@@ -1,310 +0,0 @@
-"""Residual CNN."""
-from functools import partial
-from typing import List, Type, Union
-
-from einops.layers.torch import Rearrange, Reduce
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import activation_function
-
-
-class Conv2dAuto(nn.Conv2d):
- """Convolution with auto padding based on kernel size."""
-
- def __init__(self, *args, **kwargs) -> None:
- super().__init__(*args, **kwargs)
- self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2)
-
-
-def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential:
- """3x3 convolution with batch norm."""
- conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,)
- return nn.Sequential(
- conv3x3(in_channels, out_channels, *args, **kwargs),
- nn.BatchNorm2d(out_channels),
- )
-
-
-class IdentityBlock(nn.Module):
- """Residual with identity block."""
-
- def __init__(
- self, in_channels: int, out_channels: int, activation: str = "relu"
- ) -> None:
- super().__init__()
- self.in_channels = in_channels
- self.out_channels = out_channels
- self.blocks = nn.Identity()
- self.activation_fn = activation_function(activation)
- self.shortcut = nn.Identity()
-
- def forward(self, x: Tensor) -> Tensor:
- """Forward pass."""
- residual = x
- if self.apply_shortcut:
- residual = self.shortcut(x)
- x = self.blocks(x)
- x += residual
- x = self.activation_fn(x)
- return x
-
- @property
- def apply_shortcut(self) -> bool:
- """Check if shortcut should be applied."""
- return self.in_channels != self.out_channels
-
-
-class ResidualBlock(IdentityBlock):
- """Residual with nonlinear shortcut."""
-
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- expansion: int = 1,
- downsampling: int = 1,
- *args,
- **kwargs
- ) -> None:
- """Short summary.
-
- Args:
- in_channels (int): Number of in channels.
- out_channels (int): Number of out channels.
- expansion (int): Expansion factor of the out channels. Defaults to 1.
- downsampling (int): Downsampling factor used in stride. Defaults to 1.
- *args (type): Extra arguments.
- **kwargs (type): Extra key value arguments.
-
- """
- super().__init__(in_channels, out_channels, *args, **kwargs)
- self.expansion = expansion
- self.downsampling = downsampling
-
- self.shortcut = (
- nn.Sequential(
- nn.Conv2d(
- in_channels=self.in_channels,
- out_channels=self.expanded_channels,
- kernel_size=1,
- stride=self.downsampling,
- bias=False,
- ),
- nn.BatchNorm2d(self.expanded_channels),
- )
- if self.apply_shortcut
- else None
- )
-
- @property
- def expanded_channels(self) -> int:
- """Computes the expanded output channels."""
- return self.out_channels * self.expansion
-
- @property
- def apply_shortcut(self) -> bool:
- """Check if shortcut should be applied."""
- return self.in_channels != self.expanded_channels
-
-
-class BasicBlock(ResidualBlock):
- """Basic ResNet block."""
-
- expansion = 1
-
- def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None:
- super().__init__(in_channels, out_channels, *args, **kwargs)
- self.blocks = nn.Sequential(
- conv_bn(
- in_channels=self.in_channels,
- out_channels=self.out_channels,
- bias=False,
- stride=self.downsampling,
- ),
- self.activation_fn,
- conv_bn(
- in_channels=self.out_channels,
- out_channels=self.expanded_channels,
- bias=False,
- ),
- )
-
-
-class BottleNeckBlock(ResidualBlock):
- """Bottleneck block to increase depth while minimizing parameter size."""
-
- expansion = 4
-
- def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None:
- super().__init__(in_channels, out_channels, *args, **kwargs)
- self.blocks = nn.Sequential(
- conv_bn(
- in_channels=self.in_channels,
- out_channels=self.out_channels,
- kernel_size=1,
- ),
- self.activation_fn,
- conv_bn(
- in_channels=self.out_channels,
- out_channels=self.out_channels,
- kernel_size=3,
- stride=self.downsampling,
- ),
- self.activation_fn,
- conv_bn(
- in_channels=self.out_channels,
- out_channels=self.expanded_channels,
- kernel_size=1,
- ),
- )
-
-
-class ResidualLayer(nn.Module):
- """ResNet layer."""
-
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- block: Type[nn.Module] = BasicBlock,
- num_blocks: int = 1,
- *args,
- **kwargs
- ) -> None:
- super().__init__()
- downsampling = 2 if in_channels != out_channels else 1
- self.blocks = nn.Sequential(
- block(
- in_channels, out_channels, *args, **kwargs, downsampling=downsampling
- ),
- *[
- block(
- out_channels * block.expansion,
- out_channels,
- downsampling=1,
- *args,
- **kwargs
- )
- for _ in range(num_blocks - 1)
- ]
- )
-
- def forward(self, x: Tensor) -> Tensor:
- """Forward pass."""
- x = self.blocks(x)
- return x
-
-
-class ResidualNetworkEncoder(nn.Module):
- """Encoder network."""
-
- def __init__(
- self,
- in_channels: int = 1,
- block_sizes: Union[int, List[int]] = (32, 64),
- depths: Union[int, List[int]] = (2, 2),
- activation: str = "relu",
- block: Type[nn.Module] = BasicBlock,
- levels: int = 1,
- *args,
- **kwargs
- ) -> None:
- super().__init__()
- # The documented defaults are tuples, so accept any sequence here.
- self.block_sizes = (
- list(block_sizes) if isinstance(block_sizes, (list, tuple)) else [block_sizes] * levels
- )
- self.depths = list(depths) if isinstance(depths, (list, tuple)) else [depths] * levels
- self.activation = activation
- self.gate = nn.Sequential(
- nn.Conv2d(
- in_channels=in_channels,
- out_channels=self.block_sizes[0],
- kernel_size=7,
- stride=2,
- padding=1,
- bias=False,
- ),
- nn.BatchNorm2d(self.block_sizes[0]),
- activation_function(self.activation),
- # nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
- )
-
- self.blocks = self._configure_blocks(block)
-
- def _configure_blocks(
- self, block: Type[nn.Module], *args, **kwargs
- ) -> nn.Sequential:
- channels = [self.block_sizes[0]] + list(
- zip(self.block_sizes, self.block_sizes[1:])
- )
- blocks = [
- ResidualLayer(
- in_channels=channels[0],
- out_channels=channels[0],
- num_blocks=self.depths[0],
- block=block,
- activation=self.activation,
- *args,
- **kwargs
- )
- ]
- blocks += [
- ResidualLayer(
- in_channels=in_channels * block.expansion,
- out_channels=out_channels,
- num_blocks=num_blocks,
- block=block,
- activation=self.activation,
- *args,
- **kwargs
- )
- for (in_channels, out_channels), num_blocks in zip(
- channels[1:], self.depths[1:]
- )
- ]
-
- return nn.Sequential(*blocks)
-
- def forward(self, x: Tensor) -> Tensor:
- """Forward pass."""
- # If the batch dimension is missing, it needs to be added.
- if len(x.shape) == 3:
- x = x.unsqueeze(0)
- x = self.gate(x)
- x = self.blocks(x)
- return x
-
-
-class ResidualNetworkDecoder(nn.Module):
- """Classification head."""
-
- def __init__(self, in_features: int, num_classes: int = 80) -> None:
- super().__init__()
- self.decoder = nn.Sequential(
- Reduce("b c h w -> b c", "mean"),
- nn.Linear(in_features=in_features, out_features=num_classes),
- )
-
- def forward(self, x: Tensor) -> Tensor:
- """Forward pass."""
- return self.decoder(x)
-
-
-class ResidualNetwork(nn.Module):
- """Full residual network."""
-
- def __init__(self, in_channels: int, num_classes: int, *args, **kwargs) -> None:
- super().__init__()
- self.encoder = ResidualNetworkEncoder(in_channels, *args, **kwargs)
- self.decoder = ResidualNetworkDecoder(
- in_features=self.encoder.blocks[-1].blocks[-1].expanded_channels,
- num_classes=num_classes,
- )
-
- def forward(self, x: Tensor) -> Tensor:
- """Forward pass."""
- x = self.encoder(x)
- x = self.decoder(x)
- return x
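
A minimal sketch of the auto-padding rule in `Conv2dAuto` above: a padding of kernel_size // 2 preserves the spatial size for odd kernels at stride 1. This standalone check uses a plain nn.Conv2d with the same padding:

import torch
from torch import nn

conv = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=3 // 2, bias=False)
x = torch.randn(1, 1, 28, 28)
assert conv(x).shape == (1, 8, 28, 28)  # spatial size is unchanged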
diff --git a/src/text_recognizer/networks/stn.py b/src/text_recognizer/networks/stn.py
deleted file mode 100644
index e9d216f..0000000
--- a/src/text_recognizer/networks/stn.py
+++ /dev/null
@@ -1,44 +0,0 @@
-"""Spatial Transformer Network."""
-
-from einops.layers.torch import Rearrange
-import torch
-from torch import nn
-from torch import Tensor
-import torch.nn.functional as F
-
-
-class SpatialTransformerNetwork(nn.Module):
- """A network with differentiable attention.
-
- Network that learns how to perform spatial transformations on the input image in order to enhance the
- geometric invariance of the model.
-
- # TODO: add arguments to make it more general.
-
- """
-
- def __init__(self) -> None:
- super().__init__()
- # Initialize the identity transformation and its weights and biases.
- linear = nn.Linear(32, 3 * 2)
- linear.weight.data.zero_()
- linear.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
-
- self.theta = nn.Sequential(
- nn.Conv2d(in_channels=1, out_channels=8, kernel_size=7),
- nn.MaxPool2d(kernel_size=2, stride=2),
- nn.ReLU(inplace=True),
- nn.Conv2d(in_channels=8, out_channels=10, kernel_size=5),
- nn.MaxPool2d(kernel_size=2, stride=2),
- nn.ReLU(inplace=True),
- Rearrange("b c h w -> b (c h w)", h=3, w=3),
- nn.Linear(in_features=10 * 3 * 3, out_features=32),
- nn.ReLU(inplace=True),
- linear,
- Rearrange("b (row col) -> b row col", row=2, col=3),
- )
-
- def forward(self, x: Tensor) -> Tensor:
- """The spatial transformation."""
- grid = F.affine_grid(self.theta(x), x.shape)
- return F.grid_sample(x, grid, align_corners=False)
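
A minimal sketch of the sampling step in `forward` above: with the identity parameters the STN is initialized to, `affine_grid` plus `grid_sample` reproduces the input.

import torch
import torch.nn.functional as F

theta = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]).unsqueeze(0)  # identity transform
x = torch.randn(1, 1, 28, 28)
grid = F.affine_grid(theta, x.shape, align_corners=False)
out = F.grid_sample(x, grid, align_corners=False)
assert torch.allclose(out, x, atol=1e-5)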
diff --git a/src/text_recognizer/networks/transducer/__init__.py b/src/text_recognizer/networks/transducer/__init__.py
deleted file mode 100644
index 8c19a01..0000000
--- a/src/text_recognizer/networks/transducer/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-"""Transducer modules."""
-from .tds_conv import TDS2d
-from .transducer import load_transducer_loss, Transducer
diff --git a/src/text_recognizer/networks/transducer/tds_conv.py b/src/text_recognizer/networks/transducer/tds_conv.py
deleted file mode 100644
index 5fb8ba9..0000000
--- a/src/text_recognizer/networks/transducer/tds_conv.py
+++ /dev/null
@@ -1,208 +0,0 @@
-"""Time-Depth Separable Convolutions.
-
-References:
- https://arxiv.org/abs/1904.02619
- https://arxiv.org/pdf/2010.01003.pdf
-
-Code stolen from:
- https://github.com/facebookresearch/gtn_applications
-
-
-"""
-from typing import Dict, Tuple
-
-from einops import rearrange
-import numpy as np
-from torch import nn
-from torch import Tensor
-
-
-class TDSBlock2d(nn.Module):
- """Internal block of a 2D TDSC network."""
-
- def __init__(
- self,
- in_channels: int,
- img_depth: int,
- kernel_size: Tuple[int, int],
- dropout_rate: float,
- ) -> None:
- super().__init__()
-
- self.in_channels = in_channels
- self.img_depth = img_depth
- self.kernel_size = kernel_size
- self.dropout_rate = dropout_rate
- self.fc_dim = in_channels * img_depth
-
- # Network placeholders.
- self.conv = None
- self.mlp = None
- self.instance_norm = None
-
- self._build_block()
-
- def _build_block(self) -> None:
- # Convolutional block.
- self.conv = nn.Sequential(
- nn.Conv3d(
- in_channels=self.in_channels,
- out_channels=self.in_channels,
- kernel_size=(1, self.kernel_size[0], self.kernel_size[1]),
- padding=(0, self.kernel_size[0] // 2, self.kernel_size[1] // 2),
- ),
- nn.ReLU(inplace=True),
- nn.Dropout(self.dropout_rate),
- )
-
- # MLP block.
- self.mlp = nn.Sequential(
- nn.Linear(self.fc_dim, self.fc_dim),
- nn.ReLU(inplace=True),
- nn.Dropout(self.dropout_rate),
- nn.Linear(self.fc_dim, self.fc_dim),
- nn.Dropout(self.dropout_rate),
- )
-
- # Instance norm.
- self.instance_norm = nn.ModuleList(
- [
- nn.InstanceNorm2d(self.fc_dim, affine=True),
- nn.InstanceNorm2d(self.fc_dim, affine=True),
- ]
- )
-
- def forward(self, x: Tensor) -> Tensor:
- """Forward pass.
-
- Args:
- x (Tensor): Input tensor.
-
- Shape:
- - x: :math: `(B, CD, H, W)`
-
- Returns:
- Tensor: Output tensor.
-
- """
- B, CD, H, W = x.shape
- C, D = self.in_channels, self.img_depth
- residual = x
- x = rearrange(x, "b (c d) h w -> b c d h w", c=C, d=D)
- x = self.conv(x)
- x = rearrange(x, "b c d h w -> b (c d) h w")
- x += residual
-
- x = self.instance_norm[0](x)
-
- x = self.mlp(x.transpose(1, 3)).transpose(1, 3) + x
- x = self.instance_norm[1](x)
-
- # Output shape: [B, CD, H, W]
- return x
-
-
-class TDS2d(nn.Module):
- """TDS Netowrk.
-
- Structure is the following:
- Downsample layer -> TDS2d group -> ... -> Linear output layer
-
-
- """
-
- def __init__(
- self,
- input_dim: int,
- output_dim: int,
- depth: int,
- tds_groups: Tuple[Dict, ...],
- kernel_size: Tuple[int, int],
- dropout_rate: float,
- in_channels: int = 1,
- ) -> None:
- super().__init__()
-
- self.in_channels = in_channels
- self.input_dim = input_dim
- self.output_dim = output_dim
- self.depth = depth
- self.tds_groups = tds_groups
- self.kernel_size = kernel_size
- self.dropout_rate = dropout_rate
-
- self.tds = None
- self.fc = None
-
- self._build_network()
-
- def _build_network(self) -> None:
- in_channels = self.in_channels
- modules = []
- stride_h = np.prod([grp["stride"][0] for grp in self.tds_groups])
- if self.input_dim % stride_h:
- raise RuntimeError(
- f"Image height not divisible by total stride {stride_h}."
- )
-
- for tds_group in self.tds_groups:
- # Add downsample layer.
- out_channels = self.depth * tds_group["channels"]
- modules.extend(
- [
- nn.Conv2d(
- in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=self.kernel_size,
- padding=(self.kernel_size[0] // 2, self.kernel_size[1] // 2),
- stride=tds_group["stride"],
- ),
- nn.ReLU(inplace=True),
- nn.Dropout(self.dropout_rate),
- nn.InstanceNorm2d(out_channels, affine=True),
- ]
- )
-
- for _ in range(tds_group["num_blocks"]):
- modules.append(
- TDSBlock2d(
- tds_group["channels"],
- self.depth,
- self.kernel_size,
- self.dropout_rate,
- )
- )
-
- in_channels = out_channels
-
- self.tds = nn.Sequential(*modules)
- self.fc = nn.Linear(in_channels * self.input_dim // stride_h, self.output_dim)
-
- def forward(self, x: Tensor) -> Tensor:
- """Forward pass.
-
- Args:
- x (Tensor): Input tensor.
-
- Shape:
- - x: :math: `(B, H, W)`
-
- Returns:
- Tensor: Output tensor.
-
- """
- if len(x.shape) == 4:
- x = x.squeeze(1) # Squeeze the channel dim away.
-
- B, H, W = x.shape
- x = rearrange(
- x, "b (h1 h2) w -> b h1 h2 w", h1=self.in_channels, h2=H // self.in_channels
- )
- x = self.tds(x)
-
- # x shape: [B, C, H, W]
- x = rearrange(x, "b c h w -> b w (c h)")
-
- return self.fc(x)
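
A minimal sketch of the reshaping around the TDS groups in `TDS2d.forward` above, assuming in_channels=1, a 32-pixel-tall line image, 8 output channels, and a total height stride of 4:

import torch
from einops import rearrange

x = torch.randn(2, 32, 100)                         # (B, H, W)
x = rearrange(x, "b (h1 h2) w -> b h1 h2 w", h1=1)  # height becomes the channel dim
assert x.shape == (2, 1, 32, 100)

y = torch.randn(2, 8, 32 // 4, 100)                 # after the TDS groups: (B, C, H // stride_h, W)
y = rearrange(y, "b c h w -> b w (c h)")            # fold channels and height into features
assert y.shape == (2, 100, 64)                      # (B, W, C * H // stride_h) for the linear layer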
diff --git a/src/text_recognizer/networks/transducer/test.py b/src/text_recognizer/networks/transducer/test.py
deleted file mode 100644
index cadcecc..0000000
--- a/src/text_recognizer/networks/transducer/test.py
+++ /dev/null
@@ -1,60 +0,0 @@
-"""Tests for the transducer."""
-import unittest
-
-import torch
-
-from text_recognizer.networks.transducer import Transducer
-
-
-class TestTransducer(unittest.TestCase):
- def test_viterbi(self):
- T = 5
- N = 4
- B = 2
-
- # fmt: off
- emissions1 = torch.tensor((
- 0, 4, 0, 1,
- 0, 2, 1, 1,
- 0, 0, 0, 2,
- 0, 0, 0, 2,
- 8, 0, 0, 2,
- ),
- dtype=torch.float,
- ).view(T, N)
- emissions2 = torch.tensor((
- 0, 2, 1, 7,
- 0, 2, 9, 1,
- 0, 0, 0, 2,
- 0, 0, 5, 2,
- 1, 0, 0, 2,
- ),
- dtype=torch.float,
- ).view(T, N)
- # fmt: on
-
- # Test without blank:
- labels = [[1, 3, 0], [3, 2, 3, 2, 3]]
- transducer = Transducer(
- tokens=["a", "b", "c", "d"],
- graphemes_to_idx={"a": 0, "b": 1, "c": 2, "d": 3},
- blank="none",
- )
- emissions = torch.stack([emissions1, emissions2], dim=0)
- predictions = transducer.viterbi(emissions)
- self.assertEqual([p.tolist() for p in predictions], labels)
-
- # Test with blank without repeats:
- labels = [[1, 0], [2, 2]]
- transducer = Transducer(
- tokens=["a", "b", "c"],
- graphemes_to_idx={"a": 0, "b": 1, "c": 2},
- blank="optional",
- allow_repeats=False,
- )
- emissions = torch.stack([emissions1, emissions2], dim=0)
- predictions = transducer.viterbi(emissions)
- self.assertEqual([p.tolist() for p in predictions], labels)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/src/text_recognizer/networks/transducer/transducer.py b/src/text_recognizer/networks/transducer/transducer.py
deleted file mode 100644
index d7e3d08..0000000
--- a/src/text_recognizer/networks/transducer/transducer.py
+++ /dev/null
@@ -1,410 +0,0 @@
-"""Transducer and the transducer loss function.py
-
-Stolen from:
- https://github.com/facebookresearch/gtn_applications/blob/master/transducer.py
-
-"""
-from pathlib import Path
-import itertools
-from typing import Dict, List, Optional, Union, Tuple
-
-from loguru import logger
-import gtn
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.datasets.iam_preprocessor import Preprocessor
-
-
-def make_scalar_graph(weight) -> gtn.Graph:
- scalar = gtn.Graph()
- scalar.add_node(True)
- scalar.add_node(False, True)
- scalar.add_arc(0, 1, 0, 0, weight)
- return scalar
-
-
-def make_chain_graph(sequence) -> gtn.Graph:
- graph = gtn.Graph(False)
- graph.add_node(True)
- for i, s in enumerate(sequence):
- graph.add_node(False, i == (len(sequence) - 1))
- graph.add_arc(i, i + 1, s)
- return graph
-
-
-def make_transitions_graph(
- ngram: int, num_tokens: int, calc_grad: bool = False
-) -> gtn.Graph:
- transitions = gtn.Graph(calc_grad)
- transitions.add_node(True, ngram == 1)
-
- state_map = {(): 0}
-
- # First build transitions which include <s>:
- for n in range(1, ngram):
- for state in itertools.product(range(num_tokens), repeat=n):
- in_idx = state_map[state[:-1]]
- out_idx = transitions.add_node(False, ngram == 1)
- state_map[state] = out_idx
- transitions.add_arc(in_idx, out_idx, state[-1])
-
- for state in itertools.product(range(num_tokens), repeat=ngram):
- state_idx = state_map[state[:-1]]
- new_state_idx = state_map[state[1:]]
- # p(state[-1] | state[:-1])
- transitions.add_arc(state_idx, new_state_idx, state[-1])
-
- if ngram > 1:
- # Build transitions which include </s>:
- end_idx = transitions.add_node(False, True)
- for in_idx in range(end_idx):
- transitions.add_arc(in_idx, end_idx, gtn.epsilon)
-
- return transitions
-
-
-def make_lexicon_graph(word_pieces: List, graphemes_to_idx: Dict) -> gtn.Graph:
- """Constructs a graph which transduces letters to word pieces."""
- graph = gtn.Graph(False)
- graph.add_node(True, True)
- for i, wp in enumerate(word_pieces):
- prev = 0
- for grapheme in wp[:-1]:
- n = graph.add_node()
- graph.add_arc(prev, n, graphemes_to_idx[grapheme], gtn.epsilon)
- prev = n
- graph.add_arc(prev, 0, graphemes_to_idx[wp[-1]], i)
- graph.arc_sort()
- return graph
-
-
-def make_token_graph(
- token_list: List, blank: str = "none", allow_repeats: bool = True
-) -> gtn.Graph:
- """Constructs a graph with all the individual token transition models."""
- if not allow_repeats and blank != "optional":
- raise ValueError("Must use blank='optional' if disallowing repeats.")
-
- ntoks = len(token_list)
- graph = gtn.Graph(False)
-
- # Creating nodes
- graph.add_node(True, True)
- for i in range(ntoks):
- # We can consume one or more consecutive word
- # pieces for each emission:
- # E.g. [ab, ab, ab] transduces to [ab]
- graph.add_node(False, blank != "forced")
-
- if blank != "none":
- graph.add_node()
-
- # Creating arcs
- if blank != "none":
- # Blank index is assumed to be last (ntoks)
- graph.add_arc(0, ntoks + 1, ntoks, gtn.epsilon)
- graph.add_arc(ntoks + 1, 0, gtn.epsilon)
-
- for i in range(ntoks):
- graph.add_arc((ntoks + 1) if blank == "forced" else 0, i + 1, i)
- graph.add_arc(i + 1, i + 1, i, gtn.epsilon)
-
- if allow_repeats:
- if blank == "forced":
- # Allow transitions from token to blank only
- graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon)
- else:
- # Allow transition from token to blank and all other tokens
- graph.add_arc(i + 1, 0, gtn.epsilon)
-
- else:
- # allow transitions to blank and all other tokens except the same token
- graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon)
- for j in range(ntoks):
- if i != j:
- graph.add_arc(i + 1, j + 1, j, j)
-
- return graph
-
-
-class TransducerLossFunction(torch.autograd.Function):
- @staticmethod
- def forward(
- ctx,
- inputs,
- targets,
- tokens,
- lexicon,
- transition_params=None,
- transitions=None,
- reduction="none",
- ) -> Tensor:
- B, T, C = inputs.shape
-
- losses = [None] * B
- emissions_graphs = [None] * B
-
- if transitions is not None:
- if transition_params is None:
- raise ValueError("Specified transitions, but not transition params.")
-
- cpu_data = transition_params.cpu().contiguous()
- transitions.set_weights(cpu_data.data_ptr())
- transitions.calc_grad = transition_params.requires_grad
- transitions.zero_grad()
-
- def process(b: int) -> None:
- # Create emission graph:
- emissions = gtn.linear_graph(T, C, inputs.requires_grad)
- cpu_data = inputs[b].cpu().contiguous()
- emissions.set_weights(cpu_data.data_ptr())
- target = make_chain_graph(targets[b])
- target.arc_sort(True)
-
- # Create token to grapheme decomposition graph.
- tokens_target = gtn.remove(gtn.project_output(gtn.compose(target, lexicon)))
- tokens_target.arc_sort()
-
- # Create alignment graph:
- alignments = gtn.project_input(
- gtn.remove(gtn.compose(tokens, tokens_target))
- )
- alignments.arc_sort()
-
- # Add transition scores:
- if transitions is not None:
- alignments = gtn.intersect(transitions, alignments)
- alignments.arc_sort()
-
- loss = gtn.forward_score(gtn.intersect(emissions, alignments))
-
- # Normalize if needed:
- if transitions is not None:
- norm = gtn.forward_score(gtn.intersect(emissions, transitions))
- loss = gtn.subtract(loss, norm)
-
- losses[b] = gtn.negate(loss)
-
- # Save for backward:
- if emissions.calc_grad:
- emissions_graphs[b] = emissions
-
- gtn.parallel_for(process, range(B))
-
- ctx.graphs = (losses, emissions_graphs, transitions)
- ctx.input_shape = inputs.shape
-
- # Optionally reduce by target length
- if reduction == "mean":
- scales = [(1 / len(t) if len(t) > 0 else 1.0) for t in targets]
- else:
- scales = [1.0] * B
-
- ctx.scales = scales
-
- loss = torch.tensor([l.item() * s for l, s in zip(losses, scales)])
- return torch.mean(loss.to(inputs.device))
-
- @staticmethod
- def backward(ctx, grad_output) -> Tuple:
- losses, emissions_graphs, transitions = ctx.graphs
- scales = ctx.scales
-
- B, T, C = ctx.input_shape
- calc_emissions = ctx.needs_input_grad[0]
- input_grad = torch.empty((B, T, C)) if calc_emissions else None
-
- def process(b: int) -> None:
- scale = make_scalar_graph(scales[b])
- gtn.backward(losses[b], scale)
- emissions = emissions_graphs[b]
- if calc_emissions:
- grad = emissions.grad().weights_to_numpy()
- input_grad[b] = torch.tensor(grad).view(1, T, C)
-
- gtn.parallel_for(process, range(B))
-
- if calc_emissions:
- input_grad = input_grad.to(grad_output.device)
- input_grad *= grad_output / B
-
- if ctx.needs_input_grad[4]:
- grad = transitions.grad().weights_to_numpy()
- transition_grad = torch.tensor(grad).to(grad_output.device)
- transition_grad *= grad_output / B
- else:
- transition_grad = None
-
- return (
- input_grad,
- None, # target
- None, # tokens
- None, # lexicon
- transition_grad, # transition params
- None, # transitions graph
- None,
- )
-
-
-TransducerLoss = TransducerLossFunction.apply
-
-
-class Transducer(nn.Module):
- def __init__(
- self,
- tokens: List,
- graphemes_to_idx: Dict,
- ngram: int = 0,
- transitions: Optional[gtn.Graph] = None,
- blank: str = "none",
- allow_repeats: bool = True,
- reduction: str = "none",
- ) -> None:
- """A generic transducer loss function.
-
- Args:
- tokens (List) : A list of iterable objects (e.g. strings, tuples, etc)
- representing the output tokens of the model (e.g. letters,
- word-pieces, words). For example ["a", "b", "ab", "ba", "aba"]
- could be a list of sub-word tokens.
- graphemes_to_idx (dict) : A dictionary mapping grapheme units (e.g.
- "a", "b", ..) to their corresponding integer index.
- ngram (int) : Order of the token-level transition model. If `ngram=0`
- then no transition model is used.
- blank (string) : Specifies the usage of the blank token:
- 'none' - do not use a blank token
- 'optional' - allow an optional blank in between tokens
- 'forced' - force a blank in between tokens (also referred to as a garbage token)
- allow_repeats (boolean) : If false, then we don't allow paths with
- consecutive tokens in the alignment graph. This keeps the graph
- unambiguous in the sense that the same input cannot transduce to
- different outputs.
- """
- super().__init__()
- if blank not in ["optional", "forced", "none"]:
- raise ValueError(
- "Invalid value specified for blank. Must be in ['optional', 'forced', 'none']"
- )
- self.tokens = make_token_graph(tokens, blank=blank, allow_repeats=allow_repeats)
- self.lexicon = make_lexicon_graph(tokens, graphemes_to_idx)
- self.ngram = ngram
- if ngram > 0 and transitions is not None:
- raise ValueError("Only one of ngram and transitions may be specified")
-
- if ngram > 0:
- transitions = make_transitions_graph(
- ngram, len(tokens) + int(blank != "none"), True
- )
-
- if transitions is not None:
- self.transitions = transitions
- self.transitions.arc_sort()
- self.transitions_params = nn.Parameter(
- torch.zeros(self.transitions.num_arcs())
- )
- else:
- self.transitions = None
- self.transitions_params = None
- self.reduction = reduction
-
- def forward(self, inputs: Tensor, targets: Tensor) -> Tensor:
- """Computes the transducer loss."""
- return TransducerLoss(
- inputs,
- targets,
- self.tokens,
- self.lexicon,
- self.transitions_params,
- self.transitions,
- self.reduction,
- )
-
- def viterbi(self, outputs: Tensor) -> List[Tensor]:
- B, T, C = outputs.shape
-
- if self.transitions is not None:
- cpu_data = self.transitions_params.cpu().contiguous()
- self.transitions.set_weights(cpu_data.data_ptr())
- self.transitions.calc_grad = False
-
- self.tokens.arc_sort()
-
- paths = [None] * B
-
- def process(b: int) -> None:
- emissions = gtn.linear_graph(T, C, False)
- cpu_data = outputs[b].cpu().contiguous()
- emissions.set_weights(cpu_data.data_ptr())
-
- if self.transitions is not None:
- full_graph = gtn.intersect(emissions, self.transitions)
- else:
- full_graph = emissions
-
- # Find the best path and remove back-off arcs:
- path = gtn.remove(gtn.viterbi_path(full_graph))
-
- # Left compose the viterbi path with the "alignment to token"
- # transducer to get the outputs:
- path = gtn.compose(path, self.tokens)
-
- # When there are ambiguous paths (allow_repeats is true), we take
- # the shortest:
- path = gtn.viterbi_path(path)
- path = gtn.remove(gtn.project_output(path))
- paths[b] = path.labels_to_list()
-
- gtn.parallel_for(process, range(B))
- predictions = [torch.IntTensor(path) for path in paths]
- return predictions
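
End to end, the criterion consumes emissions of shape (B, T, C), where C counts the tokens plus an optional blank, and a list of grapheme index sequences. A minimal smoke-test sketch, assuming `gtn` is installed; all sizes and inventories are toy values.

    import torch

    tokens = ["a", "b", "ab"]
    graphemes_to_idx = {"a": 0, "b": 1}
    criterion = Transducer(tokens, graphemes_to_idx, blank="optional", reduction="mean")
    inputs = torch.randn(2, 10, len(tokens) + 1, requires_grad=True)  # (B, T, C)
    targets = [torch.tensor([0, 1]), torch.tensor([1])]  # graphemes of "ab" and "b"
    loss = criterion(inputs, targets)
    loss.backward()
    best_paths = criterion.viterbi(inputs.detach())  # list of B IntTensors
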
-
-
-def load_transducer_loss(
- num_features: int,
- ngram: int,
- tokens: str,
- lexicon: str,
- transitions: str,
- blank: str,
- allow_repeats: bool,
- prepend_wordsep: bool = False,
- use_words: bool = False,
- data_dir: Optional[Union[str, Path]] = None,
- reduction: str = "mean",
-) -> Tuple[Transducer, int]:
- if data_dir is None:
- data_dir = (
- Path(__file__).resolve().parents[4] / "data" / "raw" / "iam" / "iamdb"
- )
- logger.debug(f"Using data dir: {data_dir}")
- if not data_dir.exists():
- raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
- else:
- data_dir = Path(data_dir)
- processed_path = (
- Path(__file__).resolve().parents[4] / "data" / "processed" / "iam_lines"
- )
- tokens_path = processed_path / tokens
- lexicon_path = processed_path / lexicon
-
- if transitions is not None:
- transitions = gtn.load(str(processed_path / transitions))
-
- preprocessor = Preprocessor(
- data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep,
- )
-
- num_tokens = preprocessor.num_tokens
-
- criterion = Transducer(
- preprocessor.tokens,
- preprocessor.graphemes_to_index,
- ngram=ngram,
- transitions=transitions,
- blank=blank,
- allow_repeats=allow_repeats,
- reduction=reduction,
- )
-
- return criterion, num_tokens + int(blank != "none")
diff --git a/src/text_recognizer/networks/transformer/__init__.py b/src/text_recognizer/networks/transformer/__init__.py
deleted file mode 100644
index 9febc88..0000000
--- a/src/text_recognizer/networks/transformer/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-"""Transformer modules."""
-from .positional_encoding import PositionalEncoding
-from .transformer import Decoder, Encoder, EncoderLayer, Transformer
diff --git a/src/text_recognizer/networks/transformer/attention.py b/src/text_recognizer/networks/transformer/attention.py
deleted file mode 100644
index cce1ecc..0000000
--- a/src/text_recognizer/networks/transformer/attention.py
+++ /dev/null
@@ -1,93 +0,0 @@
-"""Implementes the attention module for the transformer."""
-from typing import Optional, Tuple
-
-from einops import rearrange
-import numpy as np
-import torch
-from torch import nn
-from torch import Tensor
-
-
-class MultiHeadAttention(nn.Module):
- """Implementation of multihead attention."""
-
- def __init__(
- self, hidden_dim: int, num_heads: int = 8, dropout_rate: float = 0.0
- ) -> None:
- super().__init__()
- self.hidden_dim = hidden_dim
- self.num_heads = num_heads
- self.fc_q = nn.Linear(
- in_features=hidden_dim, out_features=hidden_dim, bias=False
- )
- self.fc_k = nn.Linear(
- in_features=hidden_dim, out_features=hidden_dim, bias=False
- )
- self.fc_v = nn.Linear(
- in_features=hidden_dim, out_features=hidden_dim, bias=False
- )
- self.fc_out = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
-
- self._init_weights()
-
- self.dropout = nn.Dropout(p=dropout_rate)
-
- def _init_weights(self) -> None:
- nn.init.normal_(
- self.fc_q.weight,
- mean=0,
- std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
- )
- nn.init.normal_(
- self.fc_k.weight,
- mean=0,
- std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
- )
- nn.init.normal_(
- self.fc_v.weight,
- mean=0,
- std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
- )
- nn.init.xavier_normal_(self.fc_out.weight)
-
- def scaled_dot_product_attention(
- self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
- ) -> Tensor:
- """Calculates the scaled dot product attention."""
-
- # Compute the energy.
- energy = torch.einsum("bhlk,bhtk->bhlt", [query, key]) / np.sqrt(
- query.shape[-1]
- )
-
- # If we have a mask for padding some inputs.
- if mask is not None:
- energy = energy.masked_fill(mask == 0, -np.inf)
-
- # Compute the attention from the energy.
- attention = torch.softmax(energy, dim=3)
-
- out = torch.einsum("bhlt,bhtv->bhlv", [attention, value])
- out = rearrange(out, "b head l v -> b l (head v)")
- return out, attention
-
- def forward(
- self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
- ) -> Tuple[Tensor, Tensor]:
- """Forward pass for computing the multihead attention."""
- # Get the query, key, and value tensor.
- query = rearrange(
- self.fc_q(query), "b l (head k) -> b head l k", head=self.num_heads
- )
- key = rearrange(
- self.fc_k(key), "b t (head k) -> b head t k", head=self.num_heads
- )
- value = rearrange(
- self.fc_v(value), "b t (head v) -> b head t v", head=self.num_heads
- )
-
- out, attention = self.scaled_dot_product_attention(query, key, value, mask)
-
- out = self.fc_out(out)
- out = self.dropout(out)
- return out, attention
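
With eight heads over a 64-dimensional model, each head attends over 8-dimensional keys and values. A shape-check sketch; the sizes are illustrative.

    import torch

    mha = MultiHeadAttention(hidden_dim=64, num_heads=8)
    x = torch.randn(2, 16, 64)  # (batch, sequence, hidden_dim)
    out, attention = mha(x, x, x)  # self-attention
    assert out.shape == (2, 16, 64)
    assert attention.shape == (2, 8, 16, 16)  # (batch, heads, query len, key len)
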
diff --git a/src/text_recognizer/networks/transformer/positional_encoding.py b/src/text_recognizer/networks/transformer/positional_encoding.py
deleted file mode 100644
index 1ba5537..0000000
--- a/src/text_recognizer/networks/transformer/positional_encoding.py
+++ /dev/null
@@ -1,32 +0,0 @@
-"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence."""
-import numpy as np
-import torch
-from torch import nn
-from torch import Tensor
-
-
-class PositionalEncoding(nn.Module):
- """Encodes a sense of distance or time for transformer networks."""
-
- def __init__(
- self, hidden_dim: int, dropout_rate: float, max_len: int = 1000
- ) -> None:
- super().__init__()
- self.dropout = nn.Dropout(p=dropout_rate)
- self.max_len = max_len
-
- pe = torch.zeros(max_len, hidden_dim)
- position = torch.arange(0, max_len).unsqueeze(1)
- div_term = torch.exp(
- torch.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim)
- )
-
- pe[:, 0::2] = torch.sin(position * div_term)
- pe[:, 1::2] = torch.cos(position * div_term)
- pe = pe.unsqueeze(0)
- self.register_buffer("pe", pe)
-
- def forward(self, x: Tensor) -> Tensor:
- """Encodes the tensor with a postional embedding."""
- x = x + self.pe[:, : x.shape[1]]
- return self.dropout(x)
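
The registered buffer adds sine and cosine waves of geometrically increasing wavelength, so every position up to `max_len` receives a unique signature. A small sketch with illustrative sizes:

    import torch

    pe = PositionalEncoding(hidden_dim=64, dropout_rate=0.0)
    x = torch.zeros(2, 10, 64)
    y = pe(x)  # each of the 10 positions now carries a distinct sinusoidal code
    assert y.shape == (2, 10, 64)
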
diff --git a/src/text_recognizer/networks/transformer/transformer.py b/src/text_recognizer/networks/transformer/transformer.py
deleted file mode 100644
index dd180c4..0000000
--- a/src/text_recognizer/networks/transformer/transformer.py
+++ /dev/null
@@ -1,264 +0,0 @@
-"""Transfomer module."""
-import copy
-from typing import Dict, Optional, Type, Union
-
-import numpy as np
-import torch
-from torch import nn
-from torch import Tensor
-import torch.nn.functional as F
-
-from text_recognizer.networks.transformer.attention import MultiHeadAttention
-from text_recognizer.networks.util import activation_function
-
-
-class GEGLU(nn.Module):
- """GLU activation for improving feedforward activations."""
-
- def __init__(self, dim_in: int, dim_out: int) -> None:
- super().__init__()
- self.proj = nn.Linear(dim_in, dim_out * 2)
-
- def forward(self, x: Tensor) -> Tensor:
- """Forward propagation."""
- x, gate = self.proj(x).chunk(2, dim=-1)
- return x * F.gelu(gate)
-
-
-def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList:
- return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)])
-
-
-class _IntraLayerConnection(nn.Module):
- """Preforms the residual connection inside the transfomer blocks and applies layernorm."""
-
- def __init__(self, dropout_rate: float, hidden_dim: int) -> None:
- super().__init__()
- self.norm = nn.LayerNorm(normalized_shape=hidden_dim)
- self.dropout = nn.Dropout(p=dropout_rate)
-
- def forward(self, src: Tensor, residual: Tensor) -> Tensor:
- return self.norm(self.dropout(src) + residual)
-
-
-class _ConvolutionalLayer(nn.Module):
- def __init__(
- self,
- hidden_dim: int,
- expansion_dim: int,
- dropout_rate: float,
- activation: str = "relu",
- ) -> None:
- super().__init__()
-
- in_projection = (
- nn.Sequential(
- nn.Linear(hidden_dim, expansion_dim), activation_function(activation)
- )
- if activation != "glu"
- else GEGLU(hidden_dim, expansion_dim)
- )
-
- self.layer = nn.Sequential(
- in_projection,
- nn.Dropout(p=dropout_rate),
- nn.Linear(in_features=expansion_dim, out_features=hidden_dim),
- )
-
- def forward(self, x: Tensor) -> Tensor:
- return self.layer(x)
-
-
-class EncoderLayer(nn.Module):
- """Transfomer encoding layer."""
-
- def __init__(
- self,
- hidden_dim: int,
- num_heads: int,
- expansion_dim: int,
- dropout_rate: float,
- activation: str = "relu",
- ) -> None:
- super().__init__()
- self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
- self.cnn = _ConvolutionalLayer(
- hidden_dim, expansion_dim, dropout_rate, activation
- )
- self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
- self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
-
- def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor:
- """Forward pass through the encoder."""
- # First block.
- # Multi head attention.
- out, _ = self.self_attention(src, src, src, mask)
-
- # Add & norm.
- out = self.block1(out, src)
-
- # Second block.
- # Apply 1D-convolution.
- cnn_out = self.cnn(out)
-
- # Add & norm.
- out = self.block2(cnn_out, out)
-
- return out
-
-
-class Encoder(nn.Module):
- """Transfomer encoder module."""
-
- def __init__(
- self,
- num_layers: int,
- encoder_layer: Type[nn.Module],
- norm: Optional[Type[nn.Module]] = None,
- ) -> None:
- super().__init__()
- self.layers = _get_clones(encoder_layer, num_layers)
- self.norm = norm
-
- def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
- """Forward pass through all encoder layers."""
- for layer in self.layers:
- src = layer(src, src_mask)
-
- if self.norm is not None:
- src = self.norm(src)
-
- return src
-
-
-class DecoderLayer(nn.Module):
- """Transfomer decoder layer."""
-
- def __init__(
- self,
- hidden_dim: int,
- num_heads: int,
- expansion_dim: int,
- dropout_rate: float = 0.0,
- activation: str = "relu",
- ) -> None:
- super().__init__()
- self.hidden_dim = hidden_dim
- self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
- self.multihead_attention = MultiHeadAttention(
- hidden_dim, num_heads, dropout_rate
- )
- self.cnn = _ConvolutionalLayer(
- hidden_dim, expansion_dim, dropout_rate, activation
- )
- self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
- self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
- self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim)
-
- def forward(
- self,
- trg: Tensor,
- memory: Tensor,
- trg_mask: Optional[Tensor] = None,
- memory_mask: Optional[Tensor] = None,
- ) -> Tensor:
- """Forward pass of the layer."""
- out, _ = self.self_attention(trg, trg, trg, trg_mask)
- trg = self.block1(out, trg)
-
- out, _ = self.multihead_attention(trg, memory, memory, memory_mask)
- trg = self.block2(out, trg)
-
- out = self.cnn(trg)
- out = self.block3(out, trg)
-
- return out
-
-
-class Decoder(nn.Module):
- """Transfomer decoder module."""
-
- def __init__(
- self,
- decoder_layer: Type[nn.Module],
- num_layers: int,
- norm: Optional[Type[nn.Module]] = None,
- ) -> None:
- super().__init__()
- self.layers = _get_clones(decoder_layer, num_layers)
- self.num_layers = num_layers
- self.norm = norm
-
- def forward(
- self,
- trg: Tensor,
- memory: Tensor,
- trg_mask: Optional[Tensor] = None,
- memory_mask: Optional[Tensor] = None,
- ) -> Tensor:
- """Forward pass through the decoder."""
- for layer in self.layers:
- trg = layer(trg, memory, trg_mask, memory_mask)
-
- if self.norm is not None:
- trg = self.norm(trg)
-
- return trg
-
-
-class Transformer(nn.Module):
- """Transformer network."""
-
- def __init__(
- self,
- num_encoder_layers: int,
- num_decoder_layers: int,
- hidden_dim: int,
- num_heads: int,
- expansion_dim: int,
- dropout_rate: float,
- activation: str = "relu",
- ) -> None:
- super().__init__()
-
- # Configure encoder.
- encoder_norm = nn.LayerNorm(hidden_dim)
- encoder_layer = EncoderLayer(
- hidden_dim, num_heads, expansion_dim, dropout_rate, activation
- )
- self.encoder = Encoder(num_encoder_layers, encoder_layer, encoder_norm)
-
- # Configure decoder.
- decoder_norm = nn.LayerNorm(hidden_dim)
- decoder_layer = DecoderLayer(
- hidden_dim, num_heads, expansion_dim, dropout_rate, activation
- )
- self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm)
-
- self._reset_parameters()
-
- def _reset_parameters(self) -> None:
- for p in self.parameters():
- if p.dim() > 1:
- nn.init.xavier_uniform_(p)
-
- def forward(
- self,
- src: Tensor,
- trg: Tensor,
- src_mask: Optional[Tensor] = None,
- trg_mask: Optional[Tensor] = None,
- memory_mask: Optional[Tensor] = None,
- ) -> Tensor:
- """Forward pass through the transformer."""
- if src.shape[0] != trg.shape[0]:
- raise RuntimeError("The batch size of the src and trg must be the same.")
- if src.shape[2] != trg.shape[2]:
- raise RuntimeError(
- "The number of features for the src and trg must be the same."
- )
-
- memory = self.encoder(src, src_mask)
- output = self.decoder(trg, memory, trg_mask, memory_mask)
- return output
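
Source and target must agree on the batch and feature dimensions, while their sequence lengths may differ. A sketch with illustrative sizes:

    import torch

    model = Transformer(
        num_encoder_layers=2,
        num_decoder_layers=2,
        hidden_dim=64,
        num_heads=8,
        expansion_dim=256,
        dropout_rate=0.1,
    )
    src = torch.randn(2, 32, 64)  # (batch, source length, features)
    trg = torch.randn(2, 12, 64)  # (batch, target length, features)
    out = model(src, trg)
    assert out.shape == (2, 12, 64)
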
diff --git a/src/text_recognizer/networks/unet.py b/src/text_recognizer/networks/unet.py
deleted file mode 100644
index 510910f..0000000
--- a/src/text_recognizer/networks/unet.py
+++ /dev/null
@@ -1,255 +0,0 @@
-"""UNet for segmentation."""
-from typing import List, Optional, Tuple, Union
-
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import activation_function
-
-
-class _ConvBlock(nn.Module):
- """Modified UNet convolutional block with dilation."""
-
- def __init__(
- self,
- channels: List[int],
- activation: str,
- num_groups: int,
- dropout_rate: float = 0.1,
- kernel_size: int = 3,
- dilation: int = 1,
- padding: int = 0,
- ) -> None:
- super().__init__()
- self.channels = channels
- self.dropout_rate = dropout_rate
- self.kernel_size = kernel_size
- self.dilation = dilation
- self.padding = padding
- self.num_groups = num_groups
- self.activation = activation_function(activation)
- self.block = self._configure_block()
- self.residual_conv = nn.Sequential(
- nn.Conv2d(
- self.channels[0], self.channels[-1], kernel_size=3, stride=1, padding=1
- ),
- self.activation,
- )
-
- def _configure_block(self) -> nn.Sequential:
- block = []
- for i in range(len(self.channels) - 1):
- block += [
- nn.Dropout(p=self.dropout_rate),
- nn.GroupNorm(self.num_groups, self.channels[i]),
- self.activation,
- nn.Conv2d(
- self.channels[i],
- self.channels[i + 1],
- kernel_size=self.kernel_size,
- padding=self.padding,
- stride=1,
- dilation=self.dilation,
- ),
- ]
-
- return nn.Sequential(*block)
-
- def forward(self, x: Tensor) -> Tensor:
- """Apply the convolutional block."""
- residual = self.residual_conv(x)
- return self.block(x) + residual
-
-
-class _DownSamplingBlock(nn.Module):
- """Basic down sampling block."""
-
- def __init__(
- self,
- channels: List[int],
- activation: str,
- num_groups: int,
- pooling_kernel: Union[int, bool] = 2,
- dropout_rate: float = 0.1,
- kernel_size: int = 3,
- dilation: int = 1,
- padding: int = 0,
- ) -> None:
- super().__init__()
- self.conv_block = _ConvBlock(
- channels,
- activation,
- num_groups,
- dropout_rate,
- kernel_size,
- dilation,
- padding,
- )
- self.down_sampling = nn.MaxPool2d(pooling_kernel) if pooling_kernel else None
-
- def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
- """Return the convolutional block output and a down sampled tensor."""
- x = self.conv_block(x)
- x_down = self.down_sampling(x) if self.down_sampling is not None else x
-
- return x_down, x
-
-
-class _UpSamplingBlock(nn.Module):
- """The upsampling block of the UNet."""
-
- def __init__(
- self,
- channels: List[int],
- activation: str,
- num_groups: int,
- scale_factor: int = 2,
- dropout_rate: float = 0.1,
- kernel_size: int = 3,
- dilation: int = 1,
- padding: int = 0,
- ) -> None:
- super().__init__()
- self.conv_block = _ConvBlock(
- channels,
- activation,
- num_groups,
- dropout_rate,
- kernel_size,
- dilation,
- padding,
- )
- self.up_sampling = nn.Upsample(
- scale_factor=scale_factor, mode="bilinear", align_corners=True
- )
-
- def forward(self, x: Tensor, x_skip: Optional[Tensor] = None) -> Tensor:
- """Apply the up sampling and convolutional block."""
- x = self.up_sampling(x)
- if x_skip is not None:
- x = torch.cat((x, x_skip), dim=1)
- return self.conv_block(x)
-
-
-class UNet(nn.Module):
- """UNet architecture."""
-
- def __init__(
- self,
- in_channels: int = 1,
- base_channels: int = 64,
- num_classes: int = 3,
- depth: int = 4,
- activation: str = "relu",
- num_groups: int = 8,
- dropout_rate: float = 0.1,
- pooling_kernel: int = 2,
- scale_factor: int = 2,
- kernel_size: Optional[List[int]] = None,
- dilation: Optional[List[int]] = None,
- padding: Optional[List[int]] = None,
- ) -> None:
- super().__init__()
- self.depth = depth
- self.num_groups = num_groups
-
- if kernel_size is not None and dilation is not None and padding is not None:
- if (
- len(kernel_size) != depth
- or len(dilation) != depth
- or len(padding) != depth
- ):
- raise RuntimeError(
- "Length of convolutional parameters does not match the depth."
- )
- self.kernel_size = kernel_size
- self.padding = padding
- self.dilation = dilation
-
- else:
- self.kernel_size = [3] * depth
- self.padding = [1] * depth
- self.dilation = [1] * depth
-
- self.dropout_rate = dropout_rate
- self.conv = nn.Conv2d(
- in_channels, base_channels, kernel_size=3, stride=1, padding=1
- )
-
- channels = [base_channels] + [base_channels * 2 ** i for i in range(depth)]
- self.encoder_blocks = self._configure_down_sampling_blocks(
- channels, activation, pooling_kernel
- )
- self.decoder_blocks = self._configure_up_sampling_blocks(
- channels, activation, scale_factor
- )
-
- self.head = nn.Conv2d(base_channels, num_classes, kernel_size=1)
-
- def _configure_down_sampling_blocks(
- self, channels: List[int], activation: str, pooling_kernel: int
- ) -> nn.ModuleList:
- blocks = nn.ModuleList([])
- for i in range(len(channels) - 1):
- pooling_kernel = pooling_kernel if i < self.depth - 1 else False
- dropout_rate = self.dropout_rate if i < 0 else 0
- blocks += [
- _DownSamplingBlock(
- [channels[i], channels[i + 1], channels[i + 1]],
- activation,
- self.num_groups,
- pooling_kernel,
- dropout_rate,
- self.kernel_size[i],
- self.dilation[i],
- self.padding[i],
- )
- ]
-
- return blocks
-
- def _configure_up_sampling_blocks(
- self, channels: List[int], activation: str, scale_factor: int,
- ) -> nn.ModuleList:
- channels.reverse()
- self.kernel_size.reverse()
- self.dilation.reverse()
- self.padding.reverse()
- return nn.ModuleList(
- [
- _UpSamplingBlock(
- [channels[i] + channels[i + 1], channels[i + 1], channels[i + 1]],
- activation,
- self.num_groups,
- scale_factor,
- self.dropout_rate,
- self.kernel_size[i],
- self.dilation[i],
- self.padding[i],
- )
- for i in range(len(channels) - 2)
- ]
- )
-
- def _encode(self, x: Tensor) -> List[Tensor]:
- x_skips = []
- for block in self.encoder_blocks:
- x, x_skip = block(x)
- x_skips.append(x_skip)
- return x_skips
-
- def _decode(self, x_skips: List[Tensor]) -> Tensor:
- x = x_skips[-1]
- for i, block in enumerate(self.decoder_blocks):
- x = block(x, x_skips[-(i + 2)])
- return x
-
- def forward(self, x: Tensor) -> Tensor:
- """Forward pass with the UNet model."""
- if len(x.shape) < 4:
- x = x[(None,) * (4 - len(x.shape))]
- x = self.conv(x)
- x_skips = self._encode(x)
- x = self._decode(x_skips)
- return self.head(x)
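
A segmentation sketch with illustrative sizes; with the default depth of 4 the input is pooled three times, so the spatial dimensions should be divisible by 8.

    import torch

    net = UNet(in_channels=1, base_channels=32, num_classes=3)
    x = torch.randn(1, 1, 64, 64)
    logits = net(x)  # per-pixel class scores
    assert logits.shape == (1, 3, 64, 64)
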
diff --git a/src/text_recognizer/networks/util.py b/src/text_recognizer/networks/util.py
deleted file mode 100644
index 131a6b4..0000000
--- a/src/text_recognizer/networks/util.py
+++ /dev/null
@@ -1,89 +0,0 @@
-"""Miscellaneous neural network functionality."""
-import importlib
-from pathlib import Path
-from typing import Dict, Tuple, Type
-
-from einops import rearrange
-from loguru import logger
-import torch
-from torch import nn
-
-
-def sliding_window(
- images: torch.Tensor, patch_size: Tuple[int, int], stride: Tuple[int, int]
-) -> torch.Tensor:
- """Creates patches of an image.
-
- Args:
- images (torch.Tensor): A Torch tensor of a 4D image(s), i.e. (batch, channel, height, width).
- patch_size (Tuple[int, int]): The size of the patches to generate, e.g. 28x28 for EMNIST.
- stride (Tuple[int, int]): The stride of the sliding window.
-
- Returns:
- torch.Tensor: A tensor with the shape (batch, patches, channels, height, width).
-
- """
- unfold = nn.Unfold(kernel_size=patch_size, stride=stride)
- # Perform the sliding window and restore the channel dimension, which unfold flattens.
- c = images.shape[1]
- patches = unfold(images)
- patches = rearrange(
- patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1],
- )
- return patches
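
For example, cutting 28x28 patches from a 28x280 line image with a stride of 28 yields ten non-overlapping patches; the sizes are illustrative.

    import torch

    images = torch.randn(1, 1, 28, 280)
    patches = sliding_window(images, patch_size=(28, 28), stride=(28, 28))
    assert patches.shape == (1, 10, 1, 28, 28)  # (batch, patches, channels, height, width)
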
-
-
-def activation_function(activation: str) -> nn.Module:
- """Returns the callable activation function."""
- activation_fns = nn.ModuleDict(
- [
- ["elu", nn.ELU(inplace=True)],
- ["gelu", nn.GELU()],
- ["glu", nn.GLU()],
- ["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)],
- ["none", nn.Identity()],
- ["relu", nn.ReLU(inplace=True)],
- ["selu", nn.SELU(inplace=True)],
- ]
- )
- return activation_fns[activation.lower()]
-
-
-def configure_backbone(backbone: str, backbone_args: Dict) -> nn.Module:
- """Loads a backbone network."""
- network_module = importlib.import_module("text_recognizer.networks")
- backbone_ = getattr(network_module, backbone)
-
- if "pretrained" in backbone_args:
- logger.info("Loading pretrained backbone.")
- checkpoint_file = Path(__file__).resolve().parents[2] / backbone_args.pop(
- "pretrained"
- )
-
- # Load the model state dict.
- state_dict = torch.load(checkpoint_file)
- network_args = state_dict["network_args"]
- weights = state_dict["model_state"]
-
- freeze = False
- if "freeze" in backbone_args and backbone_args["freeze"] is True:
- backbone_args.pop("freeze")
- freeze = True
- network_args = backbone_args
-
- # Initializes the network with trained weights.
- backbone = backbone_(**network_args)
- backbone.load_state_dict(weights)
- if freeze:
- for params in backbone.parameters():
- params.requires_grad = False
- else:
- backbone_ = getattr(network_module, backbone)
- backbone = backbone_(**backbone_args)
-
- if "remove_layers" in backbone_args and backbone_args["remove_layers"] is not None:
- backbone = nn.Sequential(
- *list(backbone.children())[:][: -backbone_args["remove_layers"]]
- )
-
- return backbone
diff --git a/src/text_recognizer/networks/vit.py b/src/text_recognizer/networks/vit.py
deleted file mode 100644
index efb3701..0000000
--- a/src/text_recognizer/networks/vit.py
+++ /dev/null
@@ -1,150 +0,0 @@
-"""A Vision Transformer.
-
-Inspired by:
-https://openreview.net/pdf?id=YicbFdNTTy
-
-"""
-from typing import Optional, Tuple
-
-from einops import rearrange, repeat
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.transformer import Transformer
-
-
-class ViT(nn.Module):
- """Transfomer for image to sequence prediction."""
-
- def __init__(
- self,
- num_encoder_layers: int,
- num_decoder_layers: int,
- hidden_dim: int,
- vocab_size: int,
- num_heads: int,
- expansion_dim: int,
- patch_dim: Tuple[int, int],
- image_size: Tuple[int, int],
- dropout_rate: float,
- trg_pad_index: int,
- max_len: int,
- activation: str = "gelu",
- ) -> None:
- super().__init__()
-
- self.trg_pad_index = trg_pad_index
- self.patch_dim = patch_dim
- self.num_patches = image_size[-1] // self.patch_dim[1]
-
- # Encoder
- self.patch_to_embedding = nn.Linear(
- self.patch_dim[0] * self.patch_dim[1], hidden_dim
- )
- self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
- self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
- self.pos_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
- self.dropout = nn.Dropout(dropout_rate)
- self._init()
-
- self.transformer = Transformer(
- num_encoder_layers,
- num_decoder_layers,
- hidden_dim,
- num_heads,
- expansion_dim,
- dropout_rate,
- activation,
- )
-
- self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
-
- def _init(self) -> None:
- nn.init.normal_(self.character_embedding.weight, std=0.02)
- # nn.init.normal_(self.pos_embedding.weight, std=0.02)
-
- def _create_trg_mask(self, trg: Tensor) -> Tensor:
- # Move this outside the transformer.
- trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
- trg_len = trg.shape[1]
- trg_sub_mask = torch.tril(
- torch.ones((trg_len, trg_len), device=trg.device)
- ).bool()
- trg_mask = trg_pad_mask & trg_sub_mask
- return trg_mask
-
- def encoder(self, src: Tensor) -> Tensor:
- """Forward pass with the encoder of the transformer."""
- return self.transformer.encoder(src)
-
- def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
- """Forward pass with the decoder of the transformer + classification head."""
- return self.head(
- self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
- )
-
- def extract_image_features(self, src: Tensor) -> Tensor:
- """Extracts image features with a backbone neural network.
-
- It seems the winning idea was to swap the channel and width dimensions and
- collapse the height dimension; the transformer trains considerably better
- with this layout.
-
- Args:
- src (Tensor): Input tensor.
-
- Returns:
- Tensor: The input src to the transformer.
-
- """
- # If batch dimension is missing, it needs to be added.
- if len(src.shape) < 4:
- src = src[(None,) * (4 - len(src.shape))]
-
- patches = rearrange(
- src,
- "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
- p1=self.patch_dim[0],
- p2=self.patch_dim[1],
- )
-
- # From patches to encoded sequence.
- x = self.patch_to_embedding(patches)
- b, n, _ = x.shape
- cls_tokens = repeat(self.cls_token, "() n d -> b n d", b=b)
- x = torch.cat((cls_tokens, x), dim=1)
- x += self.pos_embedding[:, : (n + 1)]
- x = self.dropout(x)
-
- return x
-
- def target_embedding(self, trg: Tensor) -> Tensor:
- """Encodes the target tensor with embedding and position.
-
- Args:
- trg (Tensor): Target tensor.
-
- Returns:
- Tensor: The encoded target tensor.
-
- """
- _, n = trg.shape
- trg = self.character_embedding(trg.long())
- trg += self.pos_embedding[:, :n]
- return trg
-
- def decode_image_features(self, h: Tensor, trg: Optional[Tensor] = None) -> Tensor:
- """Takes images features from the backbone and decodes them with the transformer."""
- trg_mask = self._create_trg_mask(trg)
- trg = self.target_embedding(trg)
- out = self.transformer(h, trg, trg_mask=trg_mask)
-
- logits = self.head(out)
- return logits
-
- def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
- """Forward pass with CNN transfomer."""
- h = self.extract_image_features(x)
- logits = self.decode_image_features(h, trg)
- return logits
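
A sketch with illustrative hyperparameters: 28x224 line images cut into 28x4 patches (56 patches plus the cls token), decoded into an 80-symbol vocabulary.

    import torch

    vit = ViT(
        num_encoder_layers=2,
        num_decoder_layers=2,
        hidden_dim=64,
        vocab_size=80,
        num_heads=8,
        expansion_dim=256,
        patch_dim=(28, 4),
        image_size=(28, 224),
        dropout_rate=0.1,
        trg_pad_index=79,
        max_len=128,
    )
    x = torch.randn(1, 1, 28, 224)
    trg = torch.randint(0, 79, (1, 16))
    logits = vit(x, trg)
    assert logits.shape == (1, 16, 80)
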
diff --git a/src/text_recognizer/networks/vq_transformer.py b/src/text_recognizer/networks/vq_transformer.py
deleted file mode 100644
index c673d96..0000000
--- a/src/text_recognizer/networks/vq_transformer.py
+++ /dev/null
@@ -1,150 +0,0 @@
-"""A VQ-Transformer for image to text recognition."""
-from typing import Dict, Optional, Tuple
-
-from einops import rearrange, repeat
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.transformer import PositionalEncoding, Transformer
-from text_recognizer.networks.util import activation_function
-from text_recognizer.networks.util import configure_backbone
-from text_recognizer.networks.vqvae.encoder import _ResidualBlock
-
-
-class VQTransformer(nn.Module):
- """VQ+Transfomer for image to character sequence prediction."""
-
- def __init__(
- self,
- num_encoder_layers: int,
- num_decoder_layers: int,
- hidden_dim: int,
- vocab_size: int,
- num_heads: int,
- adaptive_pool_dim: Tuple,
- expansion_dim: int,
- dropout_rate: float,
- trg_pad_index: int,
- max_len: int,
- backbone: str,
- backbone_args: Optional[Dict] = None,
- activation: str = "gelu",
- ) -> None:
- super().__init__()
-
- # Configure vector quantized backbone.
- self.backbone = configure_backbone(backbone, backbone_args)
- self.conv = nn.Sequential(
- nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2),
- nn.ReLU(inplace=True),
- )
-
- # Configure embeddings for Transformer network.
- self.trg_pad_index = trg_pad_index
- self.vocab_size = vocab_size
- self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim)
- self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
- self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
- nn.init.normal_(self.character_embedding.weight, std=0.02)
-
- self.adaptive_pool = (
- nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None
- )
-
- self.transformer = Transformer(
- num_encoder_layers,
- num_decoder_layers,
- hidden_dim,
- num_heads,
- expansion_dim,
- dropout_rate,
- activation,
- )
-
- self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
-
- def _create_trg_mask(self, trg: Tensor) -> Tensor:
- # Move this outside the transformer.
- trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
- trg_len = trg.shape[1]
- trg_sub_mask = torch.tril(
- torch.ones((trg_len, trg_len), device=trg.device)
- ).bool()
- trg_mask = trg_pad_mask & trg_sub_mask
- return trg_mask
-
- def encoder(self, src: Tensor) -> Tensor:
- """Forward pass with the encoder of the transformer."""
- return self.transformer.encoder(src)
-
- def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
- """Forward pass with the decoder of the transformer + classification head."""
- return self.head(
- self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
- )
-
- def extract_image_features(self, src: Tensor) -> Tuple[Tensor, Tensor]:
- """Extracts image features with a backbone neural network.
-
- It seems the winning idea was to swap the channel and width dimensions and
- collapse the height dimension; the transformer trains considerably better
- with this layout.
-
- Args:
- src (Tensor): Input tensor.
-
- Returns:
- Tuple[Tensor, Tensor]: The input src to the transformer and the vq loss.
-
- """
- # If batch dimension is missing, it needs to be added.
- if len(src.shape) < 4:
- src = src[(None,) * (4 - len(src.shape))]
- src, vq_loss = self.backbone.encode(src)
- # src = self.backbone.decoder.res_block(src)
- src = self.conv(src)
-
- if self.adaptive_pool is not None:
- src = rearrange(src, "b c h w -> b w c h")
- src = self.adaptive_pool(src)
- src = src.squeeze(3)
- else:
- src = rearrange(src, "b c h w -> b (w h) c")
-
- b, t, _ = src.shape
-
- src += self.src_position_embedding[:, :t]
-
- return src, vq_loss
-
- def target_embedding(self, trg: Tensor) -> Tensor:
- """Encodes target tensor with embedding and postion.
-
- Args:
- trg (Tensor): Target tensor.
-
- Returns:
- Tensor: Encoded target tensor.
-
- """
- trg = self.character_embedding(trg.long())
- trg = self.trg_position_encoding(trg)
- return trg
-
- def decode_image_features(
- self, image_features: Tensor, trg: Optional[Tensor] = None
- ) -> Tensor:
- """Takes images features from the backbone and decodes them with the transformer."""
- trg_mask = self._create_trg_mask(trg)
- trg = self.target_embedding(trg)
- out = self.transformer(image_features, trg, trg_mask=trg_mask)
-
- logits = self.head(out)
- return logits
-
- def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
- """Forward pass with CNN transfomer."""
- image_features, vq_loss = self.extract_image_features(x)
- logits = self.decode_image_features(image_features, trg)
- return logits, vq_loss
diff --git a/src/text_recognizer/networks/vqvae/__init__.py b/src/text_recognizer/networks/vqvae/__init__.py
deleted file mode 100644
index 763953c..0000000
--- a/src/text_recognizer/networks/vqvae/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-"""VQ-VAE module."""
-from .decoder import Decoder
-from .encoder import Encoder
-from .vector_quantizer import VectorQuantizer
-from .vqvae import VQVAE
diff --git a/src/text_recognizer/networks/vqvae/decoder.py b/src/text_recognizer/networks/vqvae/decoder.py
deleted file mode 100644
index 8847aba..0000000
--- a/src/text_recognizer/networks/vqvae/decoder.py
+++ /dev/null
@@ -1,133 +0,0 @@
-"""CNN decoder for the VQ-VAE."""
-
-from typing import List, Optional, Tuple, Type
-
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import activation_function
-from text_recognizer.networks.vqvae.encoder import _ResidualBlock
-
-
-class Decoder(nn.Module):
- """A CNN encoder network."""
-
- def __init__(
- self,
- channels: List[int],
- kernel_sizes: List[int],
- strides: List[int],
- num_residual_layers: int,
- embedding_dim: int,
- upsampling: Optional[List[List[int]]] = None,
- activation: str = "leaky_relu",
- dropout_rate: float = 0.0,
- ) -> None:
- super().__init__()
-
- if dropout_rate:
- if activation == "selu":
- dropout = nn.AlphaDropout(p=dropout_rate)
- else:
- dropout = nn.Dropout(p=dropout_rate)
- else:
- dropout = None
-
- self.upsampling = upsampling
-
- self.res_block = nn.ModuleList([])
- self.upsampling_block = nn.ModuleList([])
-
- self.embedding_dim = embedding_dim
- activation = activation_function(activation)
-
- # Configure decoder.
- self.decoder = self._build_decoder(
- channels, kernel_sizes, strides, num_residual_layers, activation, dropout,
- )
-
- def _build_decompression_block(
- self,
- in_channels: int,
- channels: List[int],
- kernel_sizes: List[int],
- strides: List[int],
- activation: Type[nn.Module],
- dropout: Optional[Type[nn.Module]],
- ) -> nn.ModuleList:
- modules = nn.ModuleList([])
- configuration = zip(channels, kernel_sizes, strides)
- for i, (out_channels, kernel_size, stride) in enumerate(configuration):
- modules.append(
- nn.Sequential(
- nn.ConvTranspose2d(
- in_channels,
- out_channels,
- kernel_size,
- stride=stride,
- padding=1,
- ),
- activation,
- )
- )
-
- if self.upsampling is not None and i < len(self.upsampling):
- modules.append(nn.Upsample(size=self.upsampling[i]),)
-
- if dropout is not None:
- modules.append(dropout)
-
- in_channels = out_channels
-
- modules.extend(
- nn.Sequential(
- nn.ConvTranspose2d(
- in_channels, 1, kernel_size=kernel_size, stride=stride, padding=1
- ),
- nn.Tanh(),
- )
- )
-
- return modules
-
- def _build_decoder(
- self,
- channels: List[int],
- kernel_sizes: List[int],
- strides: List[int],
- num_residual_layers: int,
- activation: Type[nn.Module],
- dropout: Optional[Type[nn.Module]],
- ) -> nn.Sequential:
-
- self.res_block.append(
- nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,)
- )
-
- # Bottleneck module.
- self.res_block.extend(
- nn.ModuleList(
- [
- _ResidualBlock(channels[0], channels[0], dropout)
- for i in range(num_residual_layers)
- ]
- )
- )
-
- # Decompression module
- self.upsampling_block.extend(
- self._build_decompression_block(
- channels[0], channels[1:], kernel_sizes, strides, activation, dropout
- )
- )
-
- self.res_block = nn.Sequential(*self.res_block)
- self.upsampling_block = nn.Sequential(*self.upsampling_block)
-
- return nn.Sequential(self.res_block, self.upsampling_block)
-
- def forward(self, z_q: Tensor) -> Tensor:
- """Reconstruct input from given codes."""
- x_reconstruction = self.decoder(z_q)
- return x_reconstruction
diff --git a/src/text_recognizer/networks/vqvae/encoder.py b/src/text_recognizer/networks/vqvae/encoder.py
deleted file mode 100644
index d3adac5..0000000
--- a/src/text_recognizer/networks/vqvae/encoder.py
+++ /dev/null
@@ -1,147 +0,0 @@
-"""CNN encoder for the VQ-VAE."""
-from typing import List, Optional, Tuple, Type
-
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import activation_function
-from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer
-
-
-class _ResidualBlock(nn.Module):
- def __init__(
- self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]],
- ) -> None:
- super().__init__()
- self.block = [
- nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
- nn.ReLU(inplace=True),
- nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False),
- ]
-
- if dropout is not None:
- self.block.append(dropout)
-
- self.block = nn.Sequential(*self.block)
-
- def forward(self, x: Tensor) -> Tensor:
- """Apply the residual forward pass."""
- return x + self.block(x)
-
-
-class Encoder(nn.Module):
- """A CNN encoder network."""
-
- def __init__(
- self,
- in_channels: int,
- channels: List[int],
- kernel_sizes: List[int],
- strides: List[int],
- num_residual_layers: int,
- embedding_dim: int,
- num_embeddings: int,
- beta: float = 0.25,
- activation: str = "leaky_relu",
- dropout_rate: float = 0.0,
- ) -> None:
- super().__init__()
-
- if dropout_rate:
- if activation == "selu":
- dropout = nn.AlphaDropout(p=dropout_rate)
- else:
- dropout = nn.Dropout(p=dropout_rate)
- else:
- dropout = None
-
- self.embedding_dim = embedding_dim
- self.num_embeddings = num_embeddings
- self.beta = beta
- activation = activation_function(activation)
-
- # Configure encoder.
- self.encoder = self._build_encoder(
- in_channels,
- channels,
- kernel_sizes,
- strides,
- num_residual_layers,
- activation,
- dropout,
- )
-
- # Configure Vector Quantizer.
- self.vector_quantizer = VectorQuantizer(
- self.num_embeddings, self.embedding_dim, self.beta
- )
-
- def _build_compression_block(
- self,
- in_channels: int,
- channels: List[int],
- kernel_sizes: List[int],
- strides: List[int],
- activation: Type[nn.Module],
- dropout: Optional[Type[nn.Module]],
- ) -> nn.ModuleList:
- modules = nn.ModuleList([])
- configuration = zip(channels, kernel_sizes, strides)
- for out_channels, kernel_size, stride in configuration:
- modules.append(
- nn.Sequential(
- nn.Conv2d(
- in_channels, out_channels, kernel_size, stride=stride, padding=1
- ),
- activation,
- )
- )
-
- if dropout is not None:
- modules.append(dropout)
-
- in_channels = out_channels
-
- return modules
-
- def _build_encoder(
- self,
- in_channels: int,
- channels: List[int],
- kernel_sizes: List[int],
- strides: List[int],
- num_residual_layers: int,
- activation: Type[nn.Module],
- dropout: Optional[Type[nn.Module]],
- ) -> nn.Sequential:
- encoder = nn.ModuleList([])
-
- # compression module
- encoder.extend(
- self._build_compression_block(
- in_channels, channels, kernel_sizes, strides, activation, dropout
- )
- )
-
- # Bottleneck module.
- encoder.extend(
- nn.ModuleList(
- [
- _ResidualBlock(channels[-1], channels[-1], dropout)
- for i in range(num_residual_layers)
- ]
- )
- )
-
- encoder.append(
- nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,)
- )
-
- return nn.Sequential(*encoder)
-
- def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
- """Encodes input into a discrete representation."""
- z_e = self.encoder(x)
- z_q, vq_loss = self.vector_quantizer(z_e)
- return z_q, vq_loss
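
A sketch with illustrative sizes: two stride-2 convolutions compress a 28x28 image to a 7x7 grid of 64-dimensional latents, which the quantizer snaps to a 512-entry codebook.

    import torch

    encoder = Encoder(
        in_channels=1,
        channels=[32, 64],
        kernel_sizes=[4, 4],
        strides=[2, 2],
        num_residual_layers=2,
        embedding_dim=64,
        num_embeddings=512,
    )
    x = torch.randn(2, 1, 28, 28)
    z_q, vq_loss = encoder(x)
    assert z_q.shape == (2, 64, 7, 7)
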
diff --git a/src/text_recognizer/networks/vqvae/vector_quantizer.py b/src/text_recognizer/networks/vqvae/vector_quantizer.py
deleted file mode 100644
index f92c7ee..0000000
--- a/src/text_recognizer/networks/vqvae/vector_quantizer.py
+++ /dev/null
@@ -1,119 +0,0 @@
-"""Implementation of a Vector Quantized Variational AutoEncoder.
-
-Reference:
-https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py
-
-"""
-
-from typing import Tuple
-
-from einops import rearrange
-import torch
-from torch import nn
-from torch import Tensor
-from torch.nn import functional as F
-
-
-class VectorQuantizer(nn.Module):
- """The codebook that contains quantized vectors."""
-
- def __init__(
- self, num_embeddings: int, embedding_dim: int, beta: float = 0.25
- ) -> None:
- super().__init__()
- self.K = num_embeddings
- self.D = embedding_dim
- self.beta = beta
-
- self.embedding = nn.Embedding(self.K, self.D)
-
- # Initialize the codebook.
- nn.init.uniform_(self.embedding.weight, -1 / self.K, 1 / self.K)
-
- def discretization_bottleneck(self, latent: Tensor) -> Tensor:
- """Computes the code nearest to the latent representation.
-
- First we compute the posterior categorical distribution, and then map
- the latent representation to the nearest element of the embedding.
-
- Args:
- latent (Tensor): The latent representation.
-
- Shape:
- - latent :math:`(B x H x W, D)`
-
- Returns:
- Tensor: The quantized embedding vector.
-
- """
- # Store latent shape.
- b, h, w, d = latent.shape
-
- # Flatten the latent representation to 2D.
- latent = rearrange(latent, "b h w d -> (b h w) d")
-
- # Compute the L2 distance between the latents and the embeddings.
- l2_distance = (
- torch.sum(latent ** 2, dim=1, keepdim=True)
- + torch.sum(self.embedding.weight ** 2, dim=1)
- - 2 * latent @ self.embedding.weight.t()
- ) # [BHW x K]
-
- # Find the embedding k nearest to each latent.
- encoding_indices = torch.argmin(l2_distance, dim=1).unsqueeze(1) # [BHW, 1]
-
- # Convert to one-hot encodings, aka discrete bottleneck.
- one_hot_encoding = torch.zeros(
- encoding_indices.shape[0], self.K, device=latent.device
- )
- one_hot_encoding.scatter_(1, encoding_indices, 1) # [BHW x K]
-
- # Embedding quantization.
- quantized_latent = one_hot_encoding @ self.embedding.weight # [BHW, D]
- quantized_latent = rearrange(
- quantized_latent, "(b h w) d -> b h w d", b=b, h=h, w=w
- )
-
- return quantized_latent
-
- def vq_loss(self, latent: Tensor, quantized_latent: Tensor) -> Tensor:
- """Vector Quantization loss.
-
- The vector quantization algorithm allows us to create a codebook. The VQ
- algorithm works by moving the embedding vectors towards the encoder outputs.
-
- The embedding loss moves the embedding vector towards the encoder outputs. The
- .detach() works as the stop gradient (sg) described in the paper.
-
- Because the volume of the embedding space is dimensionless, it can arbitrarily
- grow if the embeddings are not trained as fast as the encoder parameters. To
- mitigate this, a commitment loss is added in the second term which makes sure
- that the encoder commits to an embedding and that its output does not grow.
-
- Args:
- latent (Tensor): The encoder output.
- quantized_latent (Tensor): The quantized latent.
-
- Returns:
- Tensor: The combined VQ loss.
-
- """
- embedding_loss = F.mse_loss(quantized_latent, latent.detach())
- commitment_loss = F.mse_loss(quantized_latent.detach(), latent)
- return embedding_loss + self.beta * commitment_loss
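
In LaTeX, with sg the stop gradient and z_e, z_q the encoder output and its quantization (squared norms averaged over elements, as F.mse_loss computes), the combined objective reads:

    \mathcal{L}_{\mathrm{vq}} = \lVert z_q - \mathrm{sg}[z_e] \rVert_2^2 + \beta \, \lVert \mathrm{sg}[z_q] - z_e \rVert_2^2
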
-
- def forward(self, latent: Tensor) -> Tuple[Tensor, Tensor]:
- """Forward pass that returns the quantized vector and the vq loss."""
- # Rearrange latent representation s.t. the hidden dim is at the end.
- latent = rearrange(latent, "b d h w -> b h w d")
-
- # Maps latent to the nearest code in the codebook.
- quantized_latent = self.discretization_bottleneck(latent)
-
- loss = self.vq_loss(latent, quantized_latent)
-
- # Add residue to the quantized latent.
- quantized_latent = latent + (quantized_latent - latent).detach()
-
- # Rearrange the quantized shape back to the original shape.
- quantized_latent = rearrange(quantized_latent, "b h w d -> b d h w")
-
- return quantized_latent, loss
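
Adding the detached residue implements the straight-through estimator: gradients flow to the encoder as if the codebook lookup were the identity. A shape-check sketch with illustrative sizes:

    import torch

    vq = VectorQuantizer(num_embeddings=512, embedding_dim=64)
    z_e = torch.randn(2, 64, 7, 7)  # encoder output (B, D, H, W)
    z_q, loss = vq(z_e)
    assert z_q.shape == z_e.shape  # the latent grid keeps its shape
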
diff --git a/src/text_recognizer/networks/vqvae/vqvae.py b/src/text_recognizer/networks/vqvae/vqvae.py
deleted file mode 100644
index 50448b4..0000000
--- a/src/text_recognizer/networks/vqvae/vqvae.py
+++ /dev/null
@@ -1,74 +0,0 @@
-"""The VQ-VAE."""
-
-from typing import List, Optional, Tuple, Type
-
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.vqvae import Decoder, Encoder
-
-
-class VQVAE(nn.Module):
- """Vector Quantized Variational AutoEncoder."""
-
- def __init__(
- self,
- in_channels: int,
- channels: List[int],
- kernel_sizes: List[int],
- strides: List[int],
- num_residual_layers: int,
- embedding_dim: int,
- num_embeddings: int,
- upsampling: Optional[List[List[int]]] = None,
- beta: float = 0.25,
- activation: str = "leaky_relu",
- dropout_rate: float = 0.0,
- ) -> None:
- super().__init__()
-
- # configure encoder.
- self.encoder = Encoder(
- in_channels,
- channels,
- kernel_sizes,
- strides,
- num_residual_layers,
- embedding_dim,
- num_embeddings,
- beta,
- activation,
- dropout_rate,
- )
-
- # Configure decoder.
- channels.reverse()
- kernel_sizes.reverse()
- strides.reverse()
- self.decoder = Decoder(
- channels,
- kernel_sizes,
- strides,
- num_residual_layers,
- embedding_dim,
- upsampling,
- activation,
- dropout_rate,
- )
-
- def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
- """Encodes input to a latent code."""
- return self.encoder(x)
-
- def decode(self, z_q: Tensor) -> Tensor:
- """Reconstructs input from latent codes."""
- return self.decoder(z_q)
-
- def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
- """Compresses and decompresses input."""
- if len(x.shape) < 4:
- x = x[(None,) * (4 - len(x.shape))]
- z_q, vq_loss = self.encode(x)
- x_reconstruction = self.decode(z_q)
- return x_reconstruction, vq_loss
diff --git a/src/text_recognizer/networks/wide_resnet.py b/src/text_recognizer/networks/wide_resnet.py
deleted file mode 100644
index b767778..0000000
--- a/src/text_recognizer/networks/wide_resnet.py
+++ /dev/null
@@ -1,221 +0,0 @@
-"""Wide Residual CNN."""
-from functools import partial
-from typing import Callable, Dict, List, Optional, Type, Union
-
-from einops.layers.torch import Reduce
-import numpy as np
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.util import activation_function
-
-
-def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
- """Helper function for a 3x3 2d convolution."""
- return nn.Conv2d(
- in_channels=in_planes,
- out_channels=out_planes,
- kernel_size=3,
- stride=stride,
- padding=1,
- bias=False,
- )
-
-
-def conv_init(module: Type[nn.Module]) -> None:
- """Initializes the weights for convolution and batchnorms."""
- classname = module.__class__.__name__
- if classname.find("Conv") != -1:
- nn.init.xavier_uniform_(module.weight, gain=np.sqrt(2))
- nn.init.constant_(module.bias, 0)
- elif classname.find("BatchNorm") != -1:
- nn.init.constant_(module.weight, 1)
- nn.init.constant_(module.bias, 0)
-
-
-class WideBlock(nn.Module):
- """Block used in WideResNet."""
-
- def __init__(
- self,
- in_planes: int,
- out_planes: int,
- dropout_rate: float,
- stride: int = 1,
- activation: str = "relu",
- ) -> None:
- super().__init__()
- self.in_planes = in_planes
- self.out_planes = out_planes
- self.dropout_rate = dropout_rate
- self.stride = stride
- self.activation = activation_function(activation)
-
- # Build blocks.
- self.blocks = nn.Sequential(
- nn.BatchNorm2d(self.in_planes),
- self.activation,
- conv3x3(in_planes=self.in_planes, out_planes=self.out_planes),
- nn.Dropout(p=self.dropout_rate),
- nn.BatchNorm2d(self.out_planes),
- self.activation,
- conv3x3(
- in_planes=self.out_planes,
- out_planes=self.out_planes,
- stride=self.stride,
- ),
- )
-
- self.shortcut = (
- nn.Sequential(
- nn.Conv2d(
- in_channels=self.in_planes,
- out_channels=self.out_planes,
- kernel_size=1,
- stride=self.stride,
- bias=False,
- ),
- )
- if self._apply_shortcut
- else None
- )
-
- @property
- def _apply_shortcut(self) -> bool:
- """If shortcut should be applied or not."""
- return self.stride != 1 or self.in_planes != self.out_planes
-
- def forward(self, x: Tensor) -> Tensor:
- """Forward pass."""
- residual = x
- if self._apply_shortcut:
- residual = self.shortcut(x)
- x = self.blocks(x)
- x += residual
- return x
-
-
-class WideResidualNetwork(nn.Module):
- """WideResNet for character predictions.
-
- Can be used for classification or encoding of images to a latent vector.
-
- """
-
- def __init__(
- self,
- in_channels: int = 1,
- in_planes: int = 16,
- num_classes: int = 80,
- depth: int = 16,
- width_factor: int = 10,
- dropout_rate: float = 0.0,
- num_layers: int = 3,
- block: Type[nn.Module] = WideBlock,
- num_stages: Optional[List[int]] = None,
- activation: str = "relu",
- use_decoder: bool = True,
- ) -> None:
- """The initialization of the WideResNet.
-
- Args:
- in_channels (int): Number of input channels. Defaults to 1.
- in_planes (int): Number of channels to use in the first output kernel. Defaults to 16.
- num_classes (int): Number of classes. Defaults to 80.
- depth (int): Set the number of blocks to use. Defaults to 16.
- width_factor (int): Factor for scaling the number of channels in the network. Defaults to 10.
- dropout_rate (float): The dropout rate. Defaults to 0.0.
- num_layers (int): Number of layers of blocks. Defaults to 3.
- block (Type[nn.Module]): The default block is WideBlock. Defaults to WideBlock.
- num_stages (List[int]): If given, will use these channel values. Defaults to None.
- activation (str): Name of the activation to use. Defaults to "relu".
- use_decoder (bool): If True, the network output character predictions, if False, the network outputs a
- latent vector. Defaults to True.
-
- Raises:
- RuntimeError: If the depth is not of the size `6n+4`.
-
- """
-
- super().__init__()
- if (depth - 4) % 6 != 0:
- raise RuntimeError("Wide-resnet depth should be 6n+4")
- self.in_channels = in_channels
- self.in_planes = in_planes
- self.num_classes = num_classes
- self.num_blocks = (depth - 4) // 6
- self.width_factor = width_factor
- self.num_layers = num_layers
- self.block = block
- self.dropout_rate = dropout_rate
- self.activation = activation_function(activation)
-
- if num_stages is None:
- self.num_stages = [self.in_planes] + [
- self.in_planes * 2 ** n * self.width_factor
- for n in range(self.num_layers)
- ]
- else:
- self.num_stages = [self.in_planes] + num_stages
-
- self.num_stages = list(zip(self.num_stages, self.num_stages[1:]))
- self.strides = [1] + [2] * (self.num_layers - 1)
-
- self.encoder = nn.Sequential(
- conv3x3(in_planes=self.in_channels, out_planes=self.in_planes),
- *[
- self._configure_wide_layer(
- in_planes=in_planes,
- out_planes=out_planes,
- stride=stride,
- activation=activation,
- )
- for (in_planes, out_planes), stride in zip(
- self.num_stages, self.strides
- )
- ],
- )
-
- self.decoder = (
- nn.Sequential(
- nn.BatchNorm2d(self.num_stages[-1][-1], momentum=0.8),
- self.activation,
- Reduce("b c h w -> b c", "mean"),
- nn.Linear(
- in_features=self.num_stages[-1][-1], out_features=self.num_classes
- ),
- )
- if use_decoder
- else None
- )
-
- # self.apply(conv_init)
-
- def _configure_wide_layer(
- self, in_planes: int, out_planes: int, stride: int, activation: str
- ) -> nn.Sequential:
- strides = [stride] + [1] * (self.num_blocks - 1)
- planes = [out_planes] * len(strides)
- planes = [(in_planes, out_planes)] + list(zip(planes, planes[1:]))
- return nn.Sequential(
- *[
- self.block(
- in_planes=in_planes,
- out_planes=out_planes,
- dropout_rate=self.dropout_rate,
- stride=stride,
- activation=activation,
- )
- for (in_planes, out_planes), stride in zip(planes, strides)
- ]
- )
-
- def forward(self, x: Tensor) -> Tensor:
- """Feedforward pass."""
- if len(x.shape) < 4:
- x = x[(None,) * int(4 - len(x.shape))]
- x = self.encoder(x)
- if self.decoder is not None:
- x = self.decoder(x)
- return x
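
A classification sketch using the defaults, which correspond to a WRN-16-10 over 80 character classes; the input size is illustrative.

    import torch

    wrn = WideResidualNetwork(in_channels=1, num_classes=80, depth=16, width_factor=10)
    x = torch.randn(2, 1, 28, 28)
    logits = wrn(x)
    assert logits.shape == (2, 80)
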