diff options
Diffstat (limited to 'src/text_recognizer/networks')
-rw-r--r-- | src/text_recognizer/networks/__init__.py | 6 | ||||
-rw-r--r-- | src/text_recognizer/networks/cnn_transformer.py | 46 | ||||
-rw-r--r-- | src/text_recognizer/networks/cnn_transformer_encoder.py | 73 | ||||
-rw-r--r-- | src/text_recognizer/networks/loss/__init__.py | 2 | ||||
-rw-r--r-- | src/text_recognizer/networks/loss/loss.py (renamed from src/text_recognizer/networks/loss.py) | 0 | ||||
-rw-r--r-- | src/text_recognizer/networks/neural_machine_reader.py | 361 | ||||
-rw-r--r-- | src/text_recognizer/networks/stn.py | 2 | ||||
-rw-r--r-- | src/text_recognizer/networks/util.py | 2 | ||||
-rw-r--r-- | src/text_recognizer/networks/vision_transformer.py | 159 | ||||
-rw-r--r-- | src/text_recognizer/networks/wide_resnet.py | 2 |
10 files changed, 212 insertions, 441 deletions
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py index 6d88768..2cc1137 100644 --- a/src/text_recognizer/networks/__init__.py +++ b/src/text_recognizer/networks/__init__.py @@ -1,25 +1,20 @@ """Network modules.""" from .cnn_transformer import CNNTransformer -from .cnn_transformer_encoder import CNNTransformerEncoder from .crnn import ConvolutionalRecurrentNetwork from .ctc import greedy_decoder from .densenet import DenseNet from .lenet import LeNet -from .loss import EmbeddingLoss from .mlp import MLP from .residual_network import ResidualNetwork, ResidualNetworkEncoder from .sparse_mlp import SparseMLP from .transformer import Transformer from .util import sliding_window -from .vision_transformer import VisionTransformer from .wide_resnet import WideResidualNetwork __all__ = [ "CNNTransformer", - "CNNTransformerEncoder", "ConvolutionalRecurrentNetwork", "DenseNet", - "EmbeddingLoss", "greedy_decoder", "MLP", "LeNet", @@ -28,6 +23,5 @@ __all__ = [ "sliding_window", "Transformer", "SparseMLP", - "VisionTransformer", "WideResidualNetwork", ] diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py index 3da2c9f..16c7a41 100644 --- a/src/text_recognizer/networks/cnn_transformer.py +++ b/src/text_recognizer/networks/cnn_transformer.py @@ -1,4 +1,4 @@ -"""A DETR style transfomers but for text recognition.""" +"""A CNN-Transformer for image to text recognition.""" from typing import Dict, Optional, Tuple from einops import rearrange @@ -11,7 +11,7 @@ from text_recognizer.networks.util import configure_backbone class CNNTransformer(nn.Module): - """CNN+Transfomer for image to sequence prediction, sort of based on the ideas from DETR.""" + """CNN+Transfomer for image to sequence prediction.""" def __init__( self, @@ -25,22 +25,14 @@ class CNNTransformer(nn.Module): dropout_rate: float, trg_pad_index: int, backbone: str, - out_channels: int, - max_len: int, backbone_args: Optional[Dict] = None, activation: str = "gelu", ) -> None: super().__init__() self.trg_pad_index = trg_pad_index - self.backbone = configure_backbone(backbone, backbone_args) self.character_embedding = nn.Embedding(vocab_size, hidden_dim) - - # self.conv = nn.Conv2d(out_channels, max_len, kernel_size=1) - self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate) - self.row_embed = nn.Parameter(torch.rand(max_len, max_len // 2)) - self.col_embed = nn.Parameter(torch.rand(max_len, max_len // 2)) self.adaptive_pool = ( nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None @@ -78,8 +70,12 @@ class CNNTransformer(nn.Module): self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask) ) - def preprocess_input(self, src: Tensor) -> Tensor: - """Encodes src with a backbone network and a positional encoding. + def extract_image_features(self, src: Tensor) -> Tensor: + """Extracts image features with a backbone neural network. + + It seem like the winning idea was to swap channels and width dimension and collapse + the height dimension. The transformer is learning like a baby with this implementation!!! :D + Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D Args: src (Tensor): Input tensor. @@ -88,29 +84,19 @@ class CNNTransformer(nn.Module): Tensor: A input src to the transformer. """ - # If batch dimenstion is missing, it needs to be added. + # If batch dimension is missing, it needs to be added. if len(src.shape) < 4: src = src[(None,) * (4 - len(src.shape))] src = self.backbone(src) - # src = self.conv(src) + src = rearrange(src, "b c h w -> b w c h") if self.adaptive_pool is not None: src = self.adaptive_pool(src) - H, W = src.shape[-2:] - src = rearrange(src, "b t h w -> b t (h w)") - - # construct positional encodings - pos = torch.cat( - [ - self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1), - self.row_embed[:H].unsqueeze(1).repeat(1, W, 1), - ], - dim=-1, - ).unsqueeze(0) - pos = rearrange(pos, "b h w l -> b l (h w)") - src = pos + 0.1 * src + src = src.squeeze(3) + src = self.position_encoding(src) + return src - def preprocess_target(self, trg: Tensor) -> Tuple[Tensor, Tensor]: + def target_embedding(self, trg: Tensor) -> Tuple[Tensor, Tensor]: """Encodes target tensor with embedding and postion. Args: @@ -126,9 +112,9 @@ class CNNTransformer(nn.Module): def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor: """Forward pass with CNN transfomer.""" - h = self.preprocess_input(x) + h = self.extract_image_features(x) trg_mask = self._create_trg_mask(trg) - trg = self.preprocess_target(trg) + trg = self.target_embedding(trg) out = self.transformer(h, trg, trg_mask=trg_mask) logits = self.head(out) diff --git a/src/text_recognizer/networks/cnn_transformer_encoder.py b/src/text_recognizer/networks/cnn_transformer_encoder.py deleted file mode 100644 index 93626bf..0000000 --- a/src/text_recognizer/networks/cnn_transformer_encoder.py +++ /dev/null @@ -1,73 +0,0 @@ -"""Network with a CNN backend and a transformer encoder head.""" -from typing import Dict - -from einops import rearrange -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.transformer import PositionalEncoding -from text_recognizer.networks.util import configure_backbone - - -class CNNTransformerEncoder(nn.Module): - """A CNN backbone with Transformer Encoder frontend for sequence prediction.""" - - def __init__( - self, - backbone: str, - backbone_args: Dict, - mlp_dim: int, - d_model: int, - nhead: int = 8, - dropout_rate: float = 0.1, - activation: str = "relu", - num_layers: int = 6, - num_classes: int = 80, - num_channels: int = 256, - max_len: int = 97, - ) -> None: - super().__init__() - self.d_model = d_model - self.nhead = nhead - self.dropout_rate = dropout_rate - self.activation = activation - self.num_layers = num_layers - - self.backbone = configure_backbone(backbone, backbone_args) - self.position_encoding = PositionalEncoding(d_model, dropout_rate) - self.encoder = self._configure_encoder() - - self.conv = nn.Conv2d(num_channels, max_len, kernel_size=1) - - self.mlp = nn.Linear(mlp_dim, d_model) - - self.head = nn.Linear(d_model, num_classes) - - def _configure_encoder(self) -> nn.TransformerEncoder: - encoder_layer = nn.TransformerEncoderLayer( - d_model=self.d_model, - nhead=self.nhead, - dropout=self.dropout_rate, - activation=self.activation, - ) - norm = nn.LayerNorm(self.d_model) - return nn.TransformerEncoder( - encoder_layer=encoder_layer, num_layers=self.num_layers, norm=norm - ) - - def forward(self, x: Tensor, targets: Tensor = None) -> Tensor: - """Forward pass through the network.""" - if len(x.shape) < 4: - x = x[(None,) * (4 - len(x.shape))] - - x = self.conv(self.backbone(x)) - x = rearrange(x, "b c h w -> b c (h w)") - x = self.mlp(x) - x = self.position_encoding(x) - x = rearrange(x, "b c h-> c b h") - x = self.encoder(x) - x = rearrange(x, "c b h-> b c h") - logits = self.head(x) - - return logits diff --git a/src/text_recognizer/networks/loss/__init__.py b/src/text_recognizer/networks/loss/__init__.py new file mode 100644 index 0000000..b489264 --- /dev/null +++ b/src/text_recognizer/networks/loss/__init__.py @@ -0,0 +1,2 @@ +"""Loss module.""" +from .loss import EmbeddingLoss, LabelSmoothingCrossEntropy diff --git a/src/text_recognizer/networks/loss.py b/src/text_recognizer/networks/loss/loss.py index cf9fa0d..cf9fa0d 100644 --- a/src/text_recognizer/networks/loss.py +++ b/src/text_recognizer/networks/loss/loss.py diff --git a/src/text_recognizer/networks/neural_machine_reader.py b/src/text_recognizer/networks/neural_machine_reader.py index 540a7d2..7f8c49b 100644 --- a/src/text_recognizer/networks/neural_machine_reader.py +++ b/src/text_recognizer/networks/neural_machine_reader.py @@ -1,180 +1,201 @@ -from typing import Dict, Optional, Tuple +"""Sequence to sequence network with RNN cells.""" +# from typing import Dict, Optional, Tuple -from einops import rearrange -from einops.layers.torch import Rearrange -import torch -from torch import nn -from torch import Tensor +# from einops import rearrange +# from einops.layers.torch import Rearrange +# import torch +# from torch import nn +# from torch import Tensor -from text_recognizer.networks.util import configure_backbone +# from text_recognizer.networks.util import configure_backbone -class Encoder(nn.Module): +# class Encoder(nn.Module): +# def __init__( +# self, +# embedding_dim: int, +# encoder_dim: int, +# decoder_dim: int, +# dropout_rate: float = 0.1, +# ) -> None: +# super().__init__() +# self.rnn = nn.GRU( +# input_size=embedding_dim, hidden_size=encoder_dim, bidirectional=True +# ) +# self.fc = nn.Sequential( +# nn.Linear(in_features=2 * encoder_dim, out_features=decoder_dim), nn.Tanh() +# ) +# self.dropout = nn.Dropout(p=dropout_rate) + +# def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: +# """Encodes a sequence of tensors with a bidirectional GRU. + +# Args: +# x (Tensor): A input sequence. + +# Shape: +# - x: :math:`(T, N, E)`. +# - output[0]: :math:`(T, N, 2 * E)`. +# - output[1]: :math:`(T, N, D)`. + +# where T is the sequence length, N is the batch size, E is the +# embedding/encoder dimension, and D is the decoder dimension. + +# Returns: +# Tuple[Tensor, Tensor]: The encoder output and the hidden state of the +# encoder. + +# """ + +# output, hidden = self.rnn(x) - def __init__(self, embedding_dim: int, encoder_dim: int, decoder_dim: int, dropout_rate: float = 0.1) -> None: - super().__init__() - self.rnn = nn.GRU(input_size=embedding_dim, hidden_size=encoder_dim, bidirectional=True) - self.fc = nn.Sequential(nn.Linear(in_features=2*encoder_dim, out_features=decoder_dim), nn.Tanh()) - self.dropout = nn.Dropout(p=dropout_rate) +# # Get the hidden state from the forward and backward rnn. +# hidden_state = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1) + +# # Apply fully connected layer and tanh activation. +# hidden_state = self.fc(hidden_state) + +# return output, hidden_state + + +# class Attention(nn.Module): +# def __init__(self, encoder_dim: int, decoder_dim: int) -> None: +# super().__init__() +# self.atten = nn.Linear( +# in_features=2 * encoder_dim + decoder_dim, out_features=decoder_dim +# ) +# self.value = nn.Linear(in_features=decoder_dim, out_features=1, bias=False) - def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: - """Encodes a sequence of tensors with a bidirectional GRU. +# def forward(self, hidden_state: Tensor, encoder_outputs: Tensor) -> Tensor: +# """Short summary. - Args: - x (Tensor): A input sequence. +# Args: +# hidden_state (Tensor): Description of parameter `h`. +# encoder_outputs (Tensor): Description of parameter `enc_out`. - Shape: - - x: :math:`(T, N, E)`. - - output[0]: :math:`(T, N, 2 * E)`. - - output[1]: :math:`(T, N, D)`. +# Shape: +# - x: :math:`(T, N, E)`. +# - output[0]: :math:`(T, N, 2 * E)`. +# - output[1]: :math:`(T, N, D)`. + +# where T is the sequence length, N is the batch size, E is the +# embedding/encoder dimension, and D is the decoder dimension. + +# Returns: +# Tensor: Description of returned object. + +# """ +# t, b = enc_out.shape[:2] +# # repeat decoder hidden state src_len times +# hidden_state = hidden_state.unsqueeze(1).repeat(1, t, 1) + +# encoder_outputs = rearrange(encoder_outputs, "t b e2 -> b t e2") + +# # Calculate the energy between the decoders previous hidden state and the +# # encoders hidden states. +# energy = torch.tanh( +# self.attn(torch.cat((hidden_state, encoder_outputs), dim=2)) +# ) + +# attention = self.value(energy).squeeze(2) + +# # Apply softmax on the attention to squeeze it between 0 and 1. +# attention = F.softmax(attention, dim=1) + +# return attention + + +# class Decoder(nn.Module): +# def __init__( +# self, +# embedding_dim: int, +# encoder_dim: int, +# decoder_dim: int, +# output_dim: int, +# dropout_rate: float = 0.1, +# ) -> None: +# super().__init__() +# self.output_dim = output_dim +# self.embedding = nn.Embedding(output_dim, embedding_dim) +# self.attention = Attention(encoder_dim, decoder_dim) +# self.rnn = nn.GRU( +# input_size=2 * encoder_dim + embedding_dim, hidden_size=decoder_dim +# ) + +# self.head = nn.Linear( +# in_features=2 * encoder_dim + embedding_dim + decoder_dim, +# out_features=output_dim, +# ) +# self.dropout = nn.Dropout(p=dropout_rate) + +# def forward( +# self, trg: Tensor, hidden_state: Tensor, encoder_outputs: Tensor +# ) -> Tensor: +# # input = [batch size] +# # hidden = [batch size, dec hid dim] +# # encoder_outputs = [src len, batch size, enc hid dim * 2] +# trg = trg.unsqueeze(0) +# trg_embedded = self.dropout(self.embedding(trg)) - where T is the sequence length, N is the batch size, E is the - embedding/encoder dimension, and D is the decoder dimension. +# a = self.attention(hidden_state, encoder_outputs) + +# weighted = torch.bmm(a, encoder_outputs) - Returns: - Tuple[Tensor, Tensor]: The encoder output and the hidden state of the - encoder. - - """ - - output, hidden = self.rnn(x) - - # Get the hidden state from the forward and backward rnn. - hidden_state = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)) - - # Apply fully connected layer and tanh activation. - hidden_state = self.fc(hidden_state) - - return output, hidden_state - - -class Attention(nn.Module): - - def __init__(self, encoder_dim: int, decoder_dim: int) -> None: - super().__init__() - self.atten = nn.Linear(in_features=2*encoder_dim + decoder_dim, out_features=decoder_dim) - self.value = nn.Linear(in_features=decoder_dim, out_features=1, bias=False) - - def forward(self, hidden_state: Tensor, encoder_outputs: Tensor) -> Tensor: - """Short summary. - - Args: - hidden_state (Tensor): Description of parameter `h`. - encoder_outputs (Tensor): Description of parameter `enc_out`. - - Shape: - - x: :math:`(T, N, E)`. - - output[0]: :math:`(T, N, 2 * E)`. - - output[1]: :math:`(T, N, D)`. - - where T is the sequence length, N is the batch size, E is the - embedding/encoder dimension, and D is the decoder dimension. - - Returns: - Tensor: Description of returned object. - - """ - t, b = enc_out.shape[:2] - #repeat decoder hidden state src_len times - hidden_state = hidden_state.unsqueeze(1).repeat(1, t, 1) - - encoder_outputs = rearrange(encoder_outputs, "t b e2 -> b t e2") - - # Calculate the energy between the decoders previous hidden state and the - # encoders hidden states. - energy = torch.tanh(self.attn(torch.cat((hidden_state, encoder_outputs), dim = 2))) - - attention = self.value(energy).squeeze(2) - - # Apply softmax on the attention to squeeze it between 0 and 1. - attention = F.softmax(attention, dim=1) - - return attention - - -class Decoder(nn.Module): - - def __init__(self, embedding_dim: int, encoder_dim: int, decoder_dim: int, output_dim: int, dropout_rate: float = 0.1) -> None: - super().__init__() - self.output_dim = output_dim - self.embedding = nn.Embedding(output_dim, embedding_dim) - self.attention = Attention(encoder_dim, decoder_dim) - self.rnn = nn.GRU(input_size=2*encoder_dim + embedding_dim, hidden_size=decoder_dim) - - self.head = nn.Linear(in_features=2*encoder_dim+embedding_dim+decoder_dim, out_features=output_dim) - self.dropout = nn.Dropout(p=dropout_rate) - - def forward(self, trg: Tensor, hidden_state: Tensor, encoder_outputs: Tensor) -> Tensor: - #input = [batch size] - #hidden = [batch size, dec hid dim] - #encoder_outputs = [src len, batch size, enc hid dim * 2] - trg = trg.unsqueeze(0) - trg_embedded = self.dropout(self.embedding(trg)) - - a = self.attention(hidden_state, encoder_outputs) - - weighted = torch.bmm(a, encoder_outputs) - - # Permutate the tensor. - weighted = rearrange(weighted, "b a e2 -> a b e2") - - rnn_input = torch.cat((trg_embedded, weighted), dim = 2) - - output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0)) - - #seq len, n layers and n directions will always be 1 in this decoder, therefore: - #output = [1, batch size, dec hid dim] - #hidden = [1, batch size, dec hid dim] - #this also means that output == hidden - assert (output == hidden).all() - - trg_embedded = trg_embedded.squeeze(0) - output = output.squeeze(0) - weighted = weighted.squeeze(0) - - logits = self.fc_out(torch.cat((output, weighted, trg_embedded), dim = 1)) - - #prediction = [batch size, output dim] - - return logits, hidden.squeeze(0) - - -class NeuralMachineReader(nn.Module): - - def __init__(self, embedding_dim: int, encoder_dim: int, decoder_dim: int, output_dim: int, backbone: Optional[str] = None, - backbone_args: Optional[Dict] = None, patch_size: Tuple[int, int] = (28, 28), - stride: Tuple[int, int] = (1, 14), dropout_rate: float = 0.1, teacher_forcing_ratio: float = 0.5) -> None: - super().__init__() - self.patch_size = patch_size - self.stride = stride - self.sliding_window = self._configure_sliding_window() - - self.backbone = - self.encoder = Encoder(embedding_dim, encoder_dim, decoder_dim, dropout_rate) - self.decoder = Decoder(embedding_dim, encoder_dim, decoder_dim, output_dim, dropout_rate) - self.teacher_forcing_ratio = teacher_forcing_ratio - - def _configure_sliding_window(self) -> nn.Sequential: - return nn.Sequential( - nn.Unfold(kernel_size=self.patch_size, stride=self.stride), - Rearrange( - "b (c h w) t -> b t c h w", - h=self.patch_size[0], - w=self.patch_size[1], - c=1, - ), - ) - - def forward(self, x: Tensor, trg: Tensor) -> Tensor: - #x = [batch size, height, width] - #trg = [trg len, batch size] - - # Create image patches with a sliding window kernel. - x = self.sliding_window(x) - - # Rearrange from a sequence of patches for feedforward network. - b, t = x.shape[:2] - x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t) - - x = self.backbone(x) - x = rearrange(x, "(b t) h -> t b h", b=b, t=t) +# # Permutate the tensor. +# weighted = rearrange(weighted, "b a e2 -> a b e2") + +# rnn_input = torch.cat((trg_embedded, weighted), dim=2) + +# output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0)) + +# # seq len, n layers and n directions will always be 1 in this decoder, therefore: +# # output = [1, batch size, dec hid dim] +# # hidden = [1, batch size, dec hid dim] +# # this also means that output == hidden +# assert (output == hidden).all() + +# trg_embedded = trg_embedded.squeeze(0) +# output = output.squeeze(0) +# weighted = weighted.squeeze(0) + +# logits = self.fc_out(torch.cat((output, weighted, trg_embedded), dim=1)) + +# # prediction = [batch size, output dim] + +# return logits, hidden.squeeze(0) + + +# class NeuralMachineReader(nn.Module): +# def __init__( +# self, +# embedding_dim: int, +# encoder_dim: int, +# decoder_dim: int, +# output_dim: int, +# backbone: Optional[str] = None, +# backbone_args: Optional[Dict] = None, +# adaptive_pool_dim: Tuple = (None, 1), +# dropout_rate: float = 0.1, +# teacher_forcing_ratio: float = 0.5, +# ) -> None: +# super().__init__() + +# self.backbone = configure_backbone(backbone, backbone_args) +# self.adaptive_pool = nn.AdaptiveAvgPool2d((adaptive_pool_dim)) + +# self.encoder = Encoder(embedding_dim, encoder_dim, decoder_dim, dropout_rate) +# self.decoder = Decoder( +# embedding_dim, encoder_dim, decoder_dim, output_dim, dropout_rate +# ) +# self.teacher_forcing_ratio = teacher_forcing_ratio + +# def extract_image_features(self, x: Tensor) -> Tensor: +# x = self.backbone(x) +# x = rearrange(x, "b c h w -> b w c h") +# x = self.adaptive_pool(x) +# x = x.squeeze(3) + +# def forward(self, x: Tensor, trg: Tensor) -> Tensor: +# # x = [batch size, height, width] +# # trg = [trg len, batch size] +# z = self.extract_image_features(x) diff --git a/src/text_recognizer/networks/stn.py b/src/text_recognizer/networks/stn.py index b031128..e9d216f 100644 --- a/src/text_recognizer/networks/stn.py +++ b/src/text_recognizer/networks/stn.py @@ -13,7 +13,7 @@ class SpatialTransformerNetwork(nn.Module): Network that learns how to perform spatial transformations on the input image in order to enhance the geometric invariance of the model. - # TODO: add arguements to make it more general. + # TODO: add arguments to make it more general. """ diff --git a/src/text_recognizer/networks/util.py b/src/text_recognizer/networks/util.py index b31e640..e2d7955 100644 --- a/src/text_recognizer/networks/util.py +++ b/src/text_recognizer/networks/util.py @@ -24,7 +24,7 @@ def sliding_window( """ unfold = nn.Unfold(kernel_size=patch_size, stride=stride) - # Preform the slidning window, unsqueeze as the channel dimesion is lost. + # Preform the sliding window, unsqueeze as the channel dimesion is lost. c = images.shape[1] patches = unfold(images) patches = rearrange( diff --git a/src/text_recognizer/networks/vision_transformer.py b/src/text_recognizer/networks/vision_transformer.py deleted file mode 100644 index f227954..0000000 --- a/src/text_recognizer/networks/vision_transformer.py +++ /dev/null @@ -1,159 +0,0 @@ -"""VisionTransformer module. - -Splits each image into patches and feeds them to a transformer. - -""" - -from typing import Dict, Optional, Tuple, Type - -from einops import rearrange, reduce -from einops.layers.torch import Rearrange -from loguru import logger -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.transformer import PositionalEncoding, Transformer -from text_recognizer.networks.util import configure_backbone - - -class VisionTransformer(nn.Module): - """Linear projection+Transfomer for image to sequence prediction, sort of based on the ideas from ViT.""" - - def __init__( - self, - num_encoder_layers: int, - num_decoder_layers: int, - hidden_dim: int, - vocab_size: int, - num_heads: int, - max_len: int, - expansion_dim: int, - dropout_rate: float, - trg_pad_index: int, - mlp_dim: Optional[int] = None, - patch_size: Tuple[int, int] = (28, 28), - stride: Tuple[int, int] = (1, 14), - activation: str = "gelu", - backbone: Optional[str] = None, - backbone_args: Optional[Dict] = None, - ) -> None: - super().__init__() - - self.patch_size = patch_size - self.stride = stride - self.trg_pad_index = trg_pad_index - self.slidning_window = self._configure_sliding_window() - self.character_embedding = nn.Embedding(vocab_size, hidden_dim) - self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate, max_len) - self.mlp_dim = mlp_dim - - self.use_backbone = False - if backbone is None: - self.linear_projection = nn.Linear( - self.patch_size[0] * self.patch_size[1], hidden_dim - ) - else: - self.backbone = configure_backbone(backbone, backbone_args) - if mlp_dim: - self.mlp = nn.Linear(mlp_dim, hidden_dim) - self.use_backbone = True - - self.transformer = Transformer( - num_encoder_layers, - num_decoder_layers, - hidden_dim, - num_heads, - expansion_dim, - dropout_rate, - activation, - ) - - self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),) - - def _configure_sliding_window(self) -> nn.Sequential: - return nn.Sequential( - nn.Unfold(kernel_size=self.patch_size, stride=self.stride), - Rearrange( - "b (c h w) t -> b t c h w", - h=self.patch_size[0], - w=self.patch_size[1], - c=1, - ), - ) - - def _create_trg_mask(self, trg: Tensor) -> Tensor: - # Move this outside the transformer. - trg_pad_mask = (trg != self.trg_pad_index)[:, None, None] - trg_len = trg.shape[1] - trg_sub_mask = torch.tril( - torch.ones((trg_len, trg_len), device=trg.device) - ).bool() - trg_mask = trg_pad_mask & trg_sub_mask - return trg_mask - - def encoder(self, src: Tensor) -> Tensor: - """Forward pass with the encoder of the transformer.""" - return self.transformer.encoder(src) - - def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor: - """Forward pass with the decoder of the transformer + classification head.""" - return self.head( - self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask) - ) - - def _backbone(self, x: Tensor) -> Tensor: - b, t = x.shape[:2] - if self.use_backbone: - x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t) - x = self.backbone(x) - if self.mlp_dim: - x = rearrange(x, "(b t) c h w -> b t (c h w)", b=b, t=t) - x = self.mlp(x) - else: - x = rearrange(x, "(b t) h -> b t h", b=b, t=t) - else: - x = rearrange(x, "b t c h w -> b t (c h w)", b=b, t=t) - x = self.linear_projection(x) - return x - - def preprocess_input(self, src: Tensor) -> Tensor: - """Encodes src with a backbone network and a positional encoding. - - Args: - src (Tensor): Input tensor. - - Returns: - Tensor: A input src to the transformer. - - """ - # If batch dimenstion is missing, it needs to be added. - if len(src.shape) < 4: - src = src[(None,) * (4 - len(src.shape))] - src = self.slidning_window(src) # .squeeze(-2) - src = self._backbone(src) - src = self.position_encoding(src) - return src - - def preprocess_target(self, trg: Tensor) -> Tuple[Tensor, Tensor]: - """Encodes target tensor with embedding and postion. - - Args: - trg (Tensor): Target tensor. - - Returns: - Tuple[Tensor, Tensor]: Encoded target tensor and target mask. - - """ - trg_mask = self._create_trg_mask(trg) - trg = self.character_embedding(trg.long()) - trg = self.position_encoding(trg) - return trg, trg_mask - - def forward(self, x: Tensor, trg: Tensor) -> Tensor: - """Forward pass with vision transfomer.""" - src = self.preprocess_input(x) - trg, trg_mask = self.preprocess_target(trg) - out = self.transformer(src, trg, trg_mask=trg_mask) - logits = self.head(out) - return logits diff --git a/src/text_recognizer/networks/wide_resnet.py b/src/text_recognizer/networks/wide_resnet.py index aa79c12..28f3380 100644 --- a/src/text_recognizer/networks/wide_resnet.py +++ b/src/text_recognizer/networks/wide_resnet.py @@ -2,7 +2,7 @@ from functools import partial from typing import Callable, Dict, List, Optional, Type, Union -from einops.layers.torch import Rearrange, Reduce +from einops.layers.torch import Reduce import numpy as np import torch from torch import nn |