| Field | Value | Date |
|---|---|---|
| author | aktersnurra <gustaf.rydholm@gmail.com> | 2021-01-07 20:10:54 +0100 |
| committer | aktersnurra <gustaf.rydholm@gmail.com> | 2021-01-07 20:10:54 +0100 |
| commit | ff9a21d333f11a42e67c1963ed67de9c0fda87c9 (patch) | |
| tree | afee959135416fe92cf6df377e84fb0a9e9714a0 /src/text_recognizer | |
| parent | 25b5d6983d51e0e791b96a76beb7e49f392cd9a8 (diff) | |
Minor updates.
Diffstat (limited to 'src/text_recognizer')
| Mode | File | Lines changed |
|---|---|---|
| -rw-r--r-- | src/text_recognizer/datasets/dataset.py | 4 |
| -rw-r--r-- | src/text_recognizer/datasets/emnist_lines_dataset.py | 3 |
| -rw-r--r-- | src/text_recognizer/datasets/iam_lines_dataset.py | 2 |
| -rw-r--r-- | src/text_recognizer/datasets/transforms.py | 9 |
| -rw-r--r-- | src/text_recognizer/datasets/util.py | 15 |
| -rw-r--r-- | src/text_recognizer/models/__init__.py | 2 |
| -rw-r--r-- | src/text_recognizer/models/ctc_transformer_model.py | 120 |
| -rw-r--r-- | src/text_recognizer/models/transformer_model.py | 4 |
| -rw-r--r-- | src/text_recognizer/networks/__init__.py | 6 |
| -rw-r--r-- | src/text_recognizer/networks/cnn_transformer.py | 47 |
| -rw-r--r-- | src/text_recognizer/networks/metrics.py | 25 |
| -rw-r--r-- | src/text_recognizer/networks/residual_network.py | 4 |
| -rw-r--r-- | src/text_recognizer/networks/transformer/__init__.py | 2 |
| -rw-r--r-- | src/text_recognizer/networks/transformer/transformer.py | 26 |
| -rw-r--r-- | src/text_recognizer/networks/util.py | 1 |
| -rw-r--r-- | src/text_recognizer/networks/vit.py | 150 |
| -rw-r--r-- | src/text_recognizer/networks/wide_resnet.py | 13 |

17 files changed, 388 insertions(+), 45 deletions(-)
```diff
diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py
index 95063bc..e794605 100644
--- a/src/text_recognizer/datasets/dataset.py
+++ b/src/text_recognizer/datasets/dataset.py
@@ -22,6 +22,7 @@ class Dataset(data.Dataset):
         init_token: Optional[str] = None,
         pad_token: Optional[str] = None,
         eos_token: Optional[str] = None,
+        lower: bool = False,
     ) -> None:
         """Initialization of Dataset class.
@@ -33,6 +34,7 @@ class Dataset(data.Dataset):
             init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
             pad_token (Optional[str]): String representing the pad token. Defaults to None.
             eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
+            lower (bool): Only use lower case letters. Defaults to False.

         Raises:
             ValueError: If subsample_fraction is not None and outside the range (0, 1).
@@ -47,7 +49,7 @@ class Dataset(data.Dataset):
         self.subsample_fraction = subsample_fraction

         self._mapper = EmnistMapper(
-            init_token=init_token, eos_token=eos_token, pad_token=pad_token
+            init_token=init_token, eos_token=eos_token, pad_token=pad_token, lower=lower
         )
         self._input_shape = self._mapper.input_shape
         self._output_shape = self._mapper._num_classes
```

```diff
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index eddf341..1992446 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -44,6 +44,7 @@ class EmnistLinesDataset(Dataset):
         init_token: Optional[str] = None,
         pad_token: Optional[str] = None,
         eos_token: Optional[str] = None,
+        lower: bool = False,
     ) -> None:
         """Set attributes and loads the dataset.
@@ -60,6 +61,7 @@ class EmnistLinesDataset(Dataset):
             init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
             pad_token (Optional[str]): String representing the pad token. Defaults to None.
             eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
+            lower (bool): If True, convert uppercase letters to lowercase. Otherwise, use both upper and lowercase.

         """
         self.pad_token = "_" if pad_token is None else pad_token
@@ -72,6 +74,7 @@ class EmnistLinesDataset(Dataset):
             init_token=init_token,
             pad_token=self.pad_token,
             eos_token=eos_token,
+            lower=lower,
         )

         # Extract dataset information.
```
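The `lower` flag threaded through the dataset constructors above ultimately shrinks the character set built by `EmnistMapper` (see the `util.py` hunk below). A quick stand-alone sanity check of the resulting base mapping, 10 digits plus 26 lower-case letters:

```python
import string

# Mirrors the mapping built in EmnistMapper._augment_emnist_mapping when
# lower=True: 36 base classes before the extra IAM symbols are appended.
mapping = {
    k: str(v) for k, v in enumerate(list(range(10)) + list(string.ascii_lowercase))
}
assert len(mapping) == 36
assert mapping[10] == "a" and mapping[35] == "z"
```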
```diff
diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/src/text_recognizer/datasets/iam_lines_dataset.py
index 5ae142c..1cb84bd 100644
--- a/src/text_recognizer/datasets/iam_lines_dataset.py
+++ b/src/text_recognizer/datasets/iam_lines_dataset.py
@@ -35,6 +35,7 @@ class IamLinesDataset(Dataset):
         init_token: Optional[str] = None,
         pad_token: Optional[str] = None,
         eos_token: Optional[str] = None,
+        lower: bool = False,
     ) -> None:
         self.pad_token = "_" if pad_token is None else pad_token
@@ -46,6 +47,7 @@ class IamLinesDataset(Dataset):
             init_token=init_token,
             pad_token=pad_token,
             eos_token=eos_token,
+            lower=lower,
         )

     @property
```

```diff
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
index 016ec80..8956b01 100644
--- a/src/text_recognizer/datasets/transforms.py
+++ b/src/text_recognizer/datasets/transforms.py
@@ -93,3 +93,12 @@ class Squeeze:
     def __call__(self, x: Tensor) -> Tensor:
         """Removes first dim."""
         return x.squeeze(0)
+
+
+class ToLower:
+    """Converts target to lower case."""
+
+    def __call__(self, target: Tensor) -> Tensor:
+        """Corrects index value in target tensor."""
+        device = target.device
+        return torch.stack([x - 26 if x > 35 else x for x in target]).to(device)
```

```diff
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index bf5e772..da87756 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -1,17 +1,14 @@
 """Util functions for datasets."""
 import hashlib
-import importlib
 import json
 import os
 from pathlib import Path
 import string
-from typing import Callable, Dict, List, Optional, Type, Union
-from urllib.request import urlopen, urlretrieve
+from typing import Dict, List, Optional, Union
+from urllib.request import urlretrieve

-import cv2
 from loguru import logger
 import numpy as np
-from PIL import Image
 import torch
 from torch import Tensor
 from torchvision.datasets import EMNIST
@@ -50,11 +47,13 @@ class EmnistMapper:
         pad_token: str,
         init_token: Optional[str] = None,
         eos_token: Optional[str] = None,
+        lower: bool = False,
     ) -> None:
         """Loads the emnist essentials file with the mapping and input shape."""
         self.init_token = init_token
         self.pad_token = pad_token
         self.eos_token = eos_token
+        self.lower = lower

         self.essentials = self._load_emnist_essentials()
         # Load dataset information.
@@ -120,6 +119,12 @@ class EmnistMapper:
     def _augment_emnist_mapping(self) -> None:
         """Augment the mapping with extra symbols."""
         # Extra symbols in IAM dataset
+        if self.lower:
+            self._mapping = {
+                k: str(v)
+                for k, v in enumerate(list(range(10)) + list(string.ascii_lowercase))
+            }
+
         extra_symbols = [
             " ",
             "!",
```
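The new `ToLower` transform assumes the stock EMNIST ByClass index layout (0-9 digits, 10-35 upper case, 36-61 lower case): shifting indices above 35 down by 26 folds every lower-case index onto the slot its letter occupies in the 36-class lower-only mapping, while upper-case indices stay put and are simply reinterpreted as lower case. A worked example (the concrete index values are an assumption based on that layout):

```python
import torch

target = torch.tensor([10, 36, 5])  # "A", "a", "5" under the ByClass layout
lowered = torch.stack([x - 26 if x > 35 else x for x in target])
print(lowered.tolist())  # [10, 10, 5] -> both letters now decode to "a"
```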
```diff
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py
index a645cec..eb5dbce 100644
--- a/src/text_recognizer/models/__init__.py
+++ b/src/text_recognizer/models/__init__.py
@@ -2,12 +2,14 @@
 from .base import Model
 from .character_model import CharacterModel
 from .crnn_model import CRNNModel
+from .ctc_transformer_model import CTCTransformerModel
 from .segmentation_model import SegmentationModel
 from .transformer_model import TransformerModel

 __all__ = [
     "CharacterModel",
     "CRNNModel",
+    "CTCTransformerModel",
     "Model",
     "SegmentationModel",
     "TransformerModel",
```

```diff
diff --git a/src/text_recognizer/models/ctc_transformer_model.py b/src/text_recognizer/models/ctc_transformer_model.py
new file mode 100644
index 0000000..25925f2
--- /dev/null
+++ b/src/text_recognizer/models/ctc_transformer_model.py
@@ -0,0 +1,120 @@
+"""Defines the CTC Transformer Model class."""
+from typing import Callable, Dict, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets import EmnistMapper
+from text_recognizer.models.base import Model
+from text_recognizer.networks import greedy_decoder
+
+
+class CTCTransformerModel(Model):
+    """Model for predicting a sequence of characters from an image of a text line."""
+
+    def __init__(
+        self,
+        network_fn: Type[nn.Module],
+        dataset: Type[Dataset],
+        network_args: Optional[Dict] = None,
+        dataset_args: Optional[Dict] = None,
+        metrics: Optional[Dict] = None,
+        criterion: Optional[Callable] = None,
+        criterion_args: Optional[Dict] = None,
+        optimizer: Optional[Callable] = None,
+        optimizer_args: Optional[Dict] = None,
+        lr_scheduler: Optional[Callable] = None,
+        lr_scheduler_args: Optional[Dict] = None,
+        swa_args: Optional[Dict] = None,
+        device: Optional[str] = None,
+    ) -> None:
+        super().__init__(
+            network_fn,
+            dataset,
+            network_args,
+            dataset_args,
+            metrics,
+            criterion,
+            criterion_args,
+            optimizer,
+            optimizer_args,
+            lr_scheduler,
+            lr_scheduler_args,
+            swa_args,
+            device,
+        )
+        self.pad_token = dataset_args["args"]["pad_token"]
+        self.lower = dataset_args["args"]["lower"]
+
+        if self._mapper is None:
+            self._mapper = EmnistMapper(pad_token=self.pad_token, lower=self.lower,)
+
+        self.tensor_transform = ToTensor()
+
+    def criterion(self, output: Tensor, targets: Tensor) -> Tensor:
+        """Computes the CTC loss.
+
+        Args:
+            output (Tensor): Model predictions.
+            targets (Tensor): Correct output sequence.
+
+        Returns:
+            Tensor: The CTC loss.
+
+        """
+        # Input lengths on the form [T, B]
+        input_lengths = torch.full(
+            size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long,
+        )
+
+        # Configure target tensors for ctc loss.
+        targets_ = Tensor([]).to(self.device)
+        target_lengths = []
+        for t in targets:
+            # Remove padding symbol as it acts as the blank symbol.
+            t = t[t < 53]
+            targets_ = torch.cat([targets_, t])
+            target_lengths.append(len(t))
+
+        targets = targets_.type(dtype=torch.long)
+        target_lengths = (
+            torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device)
+        )
+
+        return self._criterion(output, targets, input_lengths, target_lengths)
+
+    @torch.no_grad()
+    def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
+        """Predict on a single input."""
+        self.eval()
+
+        if image.dtype == np.uint8:
+            # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
+            image = self.tensor_transform(image)
+
+        # Rescale image between 0 and 1.
+        if image.dtype == torch.uint8:
+            # If the image is an unscaled tensor.
+            image = image.type("torch.FloatTensor") / 255
+
+        # Put the image tensor on the device the model weights are on.
+        image = image.to(self.device)
+        log_probs = self.forward(image)
+
+        raw_pred, _ = greedy_decoder(
+            predictions=log_probs,
+            character_mapper=self.mapper,
+            blank_label=53,
+            collapse_repeated=True,
+        )
+
+        log_probs, _ = log_probs.max(dim=2)
+
+        predicted_characters = "".join(raw_pred[0])
+        confidence_of_prediction = log_probs.cumprod(dim=0)[-1].item()
+
+        return predicted_characters, confidence_of_prediction
```
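For reference, a minimal sketch of the tensor shapes `CTCTransformerModel.criterion` hands to PyTorch's `nn.CTCLoss`: log-probabilities shaped `[T, B, C]`, per-sample input lengths, and concatenated (unpadded) targets. Only the blank index 53 comes from the diff; the sizes are hypothetical:

```python
import torch
from torch import nn

T, B, C = 97, 2, 54  # time steps, batch size, number of classes (hypothetical)
log_probs = torch.randn(T, B, C).log_softmax(dim=2)

targets = torch.randint(0, 53, (20,), dtype=torch.long)  # two targets, concatenated
input_lengths = torch.full((B,), T, dtype=torch.long)
target_lengths = torch.tensor([12, 8], dtype=torch.long)  # sums to len(targets)

ctc = nn.CTCLoss(blank=53)  # the pad token doubles as the blank symbol
loss = ctc(log_probs, targets, input_lengths, target_lengths)
```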
```diff
diff --git a/src/text_recognizer/models/transformer_model.py b/src/text_recognizer/models/transformer_model.py
index a912122..12e497f 100644
--- a/src/text_recognizer/models/transformer_model.py
+++ b/src/text_recognizer/models/transformer_model.py
@@ -50,13 +50,15 @@ class TransformerModel(Model):
         self.init_token = dataset_args["args"]["init_token"]
         self.pad_token = dataset_args["args"]["pad_token"]
         self.eos_token = dataset_args["args"]["eos_token"]
-        self.max_len = 120
+        self.lower = dataset_args["args"]["lower"]
+        self.max_len = 100

         if self._mapper is None:
             self._mapper = EmnistMapper(
                 init_token=self.init_token,
                 pad_token=self.pad_token,
                 eos_token=self.eos_token,
+                lower=self.lower,
             )
         self.tensor_transform = ToTensor()
```

```diff
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index f958672..2b624bb 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -3,19 +3,18 @@
 from .cnn_transformer import CNNTransformer
 from .crnn import ConvolutionalRecurrentNetwork
 from .ctc import greedy_decoder
 from .densenet import DenseNet
-from .fcn import FCN
 from .lenet import LeNet
-from .metrics import accuracy, accuracy_ignore_pad, cer, wer
+from .metrics import accuracy, cer, wer
 from .mlp import MLP
 from .residual_network import ResidualNetwork, ResidualNetworkEncoder
 from .transformer import Transformer
 from .unet import UNet
 from .util import sliding_window
+from .vit import ViT
 from .wide_resnet import WideResidualNetwork

 __all__ = [
     "accuracy",
-    "accuracy_ignore_pad",
     "cer",
     "CNNTransformer",
     "ConvolutionalRecurrentNetwork",
@@ -29,6 +28,7 @@ __all__ = [
     "sliding_window",
     "UNet",
     "Transformer",
+    "ViT",
     "wer",
     "WideResidualNetwork",
 ]
```
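`predict_on_image` above relies on the repo's `greedy_decoder` (re-exported from `text_recognizer.networks`). The general technique, sketched independently of that implementation: take the argmax class per time step, then drop repeats and blanks:

```python
import torch

def greedy_ctc_decode(log_probs: torch.Tensor, blank: int = 53) -> list:
    """Collapse [T, B, C] log-probs into index sequences (illustrative only)."""
    best = log_probs.argmax(dim=2)  # [T, B] best class per time step
    decoded = []
    for b in range(best.shape[1]):
        seq, prev = [], None
        for idx in best[:, b].tolist():
            if idx != blank and idx != prev:  # skip blanks and repeats
                seq.append(idx)
            prev = idx
        decoded.append(seq)
    return decoded
```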
```diff
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
index b2b74b3..caa73e3 100644
--- a/src/text_recognizer/networks/cnn_transformer.py
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -1,12 +1,13 @@
 """A CNN-Transformer for image to text recognition."""
 from typing import Dict, Optional, Tuple

-from einops import rearrange
+from einops import rearrange, repeat
 import torch
 from torch import nn
 from torch import Tensor

 from text_recognizer.networks.transformer import PositionalEncoding, Transformer
+from text_recognizer.networks.util import activation_function
 from text_recognizer.networks.util import configure_backbone
@@ -24,15 +25,21 @@ class CNNTransformer(nn.Module):
         expansion_dim: int,
         dropout_rate: float,
         trg_pad_index: int,
+        max_len: int,
         backbone: str,
         backbone_args: Optional[Dict] = None,
         activation: str = "gelu",
     ) -> None:
         super().__init__()
         self.trg_pad_index = trg_pad_index
+        self.vocab_size = vocab_size
         self.backbone = configure_backbone(backbone, backbone_args)
-        self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
-        self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+        self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim)
+
+        self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
+        self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+
+        nn.init.normal_(self.character_embedding.weight, std=0.02)

         self.adaptive_pool = (
             nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None
@@ -48,7 +55,11 @@ class CNNTransformer(nn.Module):
             activation,
         )

-        self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
+        self.head = nn.Sequential(
+            # nn.Linear(hidden_dim, hidden_dim * 2),
+            # activation_function(activation),
+            nn.Linear(hidden_dim, vocab_size),
+        )

     def _create_trg_mask(self, trg: Tensor) -> Tensor:
         # Move this outside the transformer.
@@ -96,7 +107,21 @@ class CNNTransformer(nn.Module):
         else:
             src = rearrange(src, "b c h w -> b (w h) c")

-        src = self.position_encoding(src)
+        b, t, _ = src.shape
+
+        # Insert sos and eos token.
+        # sos_token = self.character_embedding(
+        #     torch.Tensor([self.vocab_size - 2]).long().to(src.device)
+        # )
+        # eos_token = self.character_embedding(
+        #     torch.Tensor([self.vocab_size - 1]).long().to(src.device)
+        # )
+
+        # sos_tokens = repeat(sos_token, "() h -> b h", b=b).unsqueeze(1)
+        # eos_tokens = repeat(eos_token, "() h -> b h", b=b).unsqueeze(1)
+        # src = torch.cat((sos_tokens, src, eos_tokens), dim=1)
+        # src = torch.cat((sos_tokens, src), dim=1)
+        src += self.src_position_embedding[:, :t]

         return src
@@ -111,20 +136,22 @@ class CNNTransformer(nn.Module):

         """
         trg = self.character_embedding(trg.long())
-        trg = self.position_encoding(trg)
+        trg = self.trg_position_encoding(trg)
         return trg

-    def decode_image_features(self, h: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+    def decode_image_features(
+        self, image_features: Tensor, trg: Optional[Tensor] = None
+    ) -> Tensor:
         """Takes images features from the backbone and decodes them with the transformer."""
         trg_mask = self._create_trg_mask(trg)
         trg = self.target_embedding(trg)
-        out = self.transformer(h, trg, trg_mask=trg_mask)
+        out = self.transformer(image_features, trg, trg_mask=trg_mask)

         logits = self.head(out)
         return logits

     def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
         """Forward pass with CNN transfomer."""
-        h = self.extract_image_features(x)
-        logits = self.decode_image_features(h, trg)
+        image_features = self.extract_image_features(x)
+        logits = self.decode_image_features(image_features, trg)
         return logits
```
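The `cnn_transformer.py` change swaps the sinusoidal encoding on the source side for a learned positional embedding: a `(1, max_len, hidden_dim)` parameter sliced to the sequence length and broadcast over the batch. In isolation:

```python
import torch
from torch import nn

max_len, hidden_dim = 128, 64  # hypothetical sizes
src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))

src = torch.randn(8, 100, hidden_dim)  # (batch, seq_len, hidden_dim)
b, t, _ = src.shape
src = src + src_position_embedding[:, :t]  # leading 1 broadcasts over the batch
```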
```diff
diff --git a/src/text_recognizer/networks/metrics.py b/src/text_recognizer/networks/metrics.py
index af9adb5..ffad792 100644
--- a/src/text_recognizer/networks/metrics.py
+++ b/src/text_recognizer/networks/metrics.py
@@ -6,28 +6,13 @@ from torch import Tensor
 from text_recognizer.networks import greedy_decoder


-def accuracy_ignore_pad(
-    output: Tensor,
-    target: Tensor,
-    pad_index: int = 79,
-    eos_index: int = 81,
-    seq_len: int = 97,
-) -> float:
-    """Sets all predictions after eos to pad."""
-    start_indices = torch.nonzero(target == eos_index, as_tuple=False).squeeze(1)
-    end_indices = torch.arange(seq_len, target.shape[0] + 1, seq_len)
-    for start, stop in zip(start_indices, end_indices):
-        output[start + 1 : stop] = pad_index
-
-    return accuracy(output, target)
-
-
-def accuracy(outputs: Tensor, labels: Tensor,) -> float:
+def accuracy(outputs: Tensor, labels: Tensor, pad_index: int = 53) -> float:
     """Computes the accuracy.

     Args:
         outputs (Tensor): The output from the network.
         labels (Tensor): Ground truth labels.
+        pad_index (int): Padding index.

     Returns:
         float: The accuracy for the batch.
@@ -36,6 +21,12 @@ def accuracy(outputs: Tensor, labels: Tensor,) -> float:

     _, predicted = torch.max(outputs, dim=-1)

+    # Mask out the pad tokens
+    mask = labels != pad_index
+
+    predicted *= mask
+    labels *= mask
+
     acc = (predicted == labels).sum().float() / labels.shape[0]
     acc = acc.item()
     return acc
```

```diff
diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py
index e397224..c33f419 100644
--- a/src/text_recognizer/networks/residual_network.py
+++ b/src/text_recognizer/networks/residual_network.py
@@ -221,8 +221,8 @@ class ResidualNetworkEncoder(nn.Module):
             nn.Conv2d(
                 in_channels=in_channels,
                 out_channels=self.block_sizes[0],
-                kernel_size=3,
-                stride=1,
+                kernel_size=7,
+                stride=2,
                 padding=1,
                 bias=False,
             ),
```

```diff
diff --git a/src/text_recognizer/networks/transformer/__init__.py b/src/text_recognizer/networks/transformer/__init__.py
index 020a917..9febc88 100644
--- a/src/text_recognizer/networks/transformer/__init__.py
+++ b/src/text_recognizer/networks/transformer/__init__.py
@@ -1,3 +1,3 @@
 """Transformer modules."""
 from .positional_encoding import PositionalEncoding
-from .transformer import Decoder, Encoder, Transformer
+from .transformer import Decoder, Encoder, EncoderLayer, Transformer
```

```diff
diff --git a/src/text_recognizer/networks/transformer/transformer.py b/src/text_recognizer/networks/transformer/transformer.py
index c6e943e..dd180c4 100644
--- a/src/text_recognizer/networks/transformer/transformer.py
+++ b/src/text_recognizer/networks/transformer/transformer.py
@@ -6,11 +6,25 @@
 import numpy as np
 import torch
 from torch import nn
 from torch import Tensor
+import torch.nn.functional as F

 from text_recognizer.networks.transformer.attention import MultiHeadAttention
 from text_recognizer.networks.util import activation_function


+class GEGLU(nn.Module):
+    """GLU activation for improving feedforward activations."""
+
+    def __init__(self, dim_in: int, dim_out: int) -> None:
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward propagation."""
+        x, gate = self.proj(x).chunk(2, dim=-1)
+        return x * F.gelu(gate)
+
+
 def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList:
     return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)])
@@ -36,9 +50,17 @@ class _ConvolutionalLayer(nn.Module):
         activation: str = "relu",
     ) -> None:
         super().__init__()
+
+        in_projection = (
+            nn.Sequential(
+                nn.Linear(hidden_dim, expansion_dim), activation_function(activation)
+            )
+            if activation != "glu"
+            else GEGLU(hidden_dim, expansion_dim)
+        )
+
         self.layer = nn.Sequential(
-            nn.Linear(in_features=hidden_dim, out_features=expansion_dim),
-            activation_function(activation),
+            in_projection,
             nn.Dropout(p=dropout_rate),
             nn.Linear(in_features=expansion_dim, out_features=hidden_dim),
         )
```
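Two notes on the `transformer.py` hunk: `GEGLU` computes value and gate from a single doubled-width projection and gates with GELU, and `_ConvolutionalLayer` special-cases `"glu"` to use it, presumably because a plain `nn.GLU` would halve `expansion_dim`. The gating in isolation:

```python
import torch
from torch import nn
import torch.nn.functional as F

proj = nn.Linear(16, 32 * 2)  # one projection yields both value and gate
x = torch.randn(4, 10, 16)
value, gate = proj(x).chunk(2, dim=-1)
out = value * F.gelu(gate)  # shape (4, 10, 32)
```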
```diff
diff --git a/src/text_recognizer/networks/util.py b/src/text_recognizer/networks/util.py
index e2d7955..711a952 100644
--- a/src/text_recognizer/networks/util.py
+++ b/src/text_recognizer/networks/util.py
@@ -39,6 +39,7 @@ def activation_function(activation: str) -> Type[nn.Module]:
         [
             ["elu", nn.ELU(inplace=True)],
             ["gelu", nn.GELU()],
+            ["glu", nn.GLU()],
             ["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)],
             ["none", nn.Identity()],
             ["relu", nn.ReLU(inplace=True)],
```
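The `"glu"` entry added to `activation_function` maps to PyTorch's `nn.GLU`, which, unlike the other activations in the table, halves the last dimension — the reason the transformer layer above routes `"glu"` to `GEGLU` instead:

```python
import torch
from torch import nn

glu = nn.GLU()
x = torch.randn(2, 8)
print(glu(x).shape)  # torch.Size([2, 4]) -- half of the input width survives
```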
```diff
diff --git a/src/text_recognizer/networks/vit.py b/src/text_recognizer/networks/vit.py
new file mode 100644
index 0000000..efb3701
--- /dev/null
+++ b/src/text_recognizer/networks/vit.py
@@ -0,0 +1,150 @@
+"""A Vision Transformer.
+
+Inspired by:
+https://openreview.net/pdf?id=YicbFdNTTy
+
+"""
+from typing import Optional, Tuple
+
+from einops import rearrange, repeat
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer import Transformer
+
+
+class ViT(nn.Module):
+    """Transfomer for image to sequence prediction."""
+
+    def __init__(
+        self,
+        num_encoder_layers: int,
+        num_decoder_layers: int,
+        hidden_dim: int,
+        vocab_size: int,
+        num_heads: int,
+        expansion_dim: int,
+        patch_dim: Tuple[int, int],
+        image_size: Tuple[int, int],
+        dropout_rate: float,
+        trg_pad_index: int,
+        max_len: int,
+        activation: str = "gelu",
+    ) -> None:
+        super().__init__()
+
+        self.trg_pad_index = trg_pad_index
+        self.patch_dim = patch_dim
+        self.num_patches = image_size[-1] // self.patch_dim[1]
+
+        # Encoder
+        self.patch_to_embedding = nn.Linear(
+            self.patch_dim[0] * self.patch_dim[1], hidden_dim
+        )
+        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
+        self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
+        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
+        self.dropout = nn.Dropout(dropout_rate)
+        self._init()
+
+        self.transformer = Transformer(
+            num_encoder_layers,
+            num_decoder_layers,
+            hidden_dim,
+            num_heads,
+            expansion_dim,
+            dropout_rate,
+            activation,
+        )
+
+        self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
+
+    def _init(self) -> None:
+        nn.init.normal_(self.character_embedding.weight, std=0.02)
+        # nn.init.normal_(self.pos_embedding.weight, std=0.02)
+
+    def _create_trg_mask(self, trg: Tensor) -> Tensor:
+        # Move this outside the transformer.
+        trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
+        trg_len = trg.shape[1]
+        trg_sub_mask = torch.tril(
+            torch.ones((trg_len, trg_len), device=trg.device)
+        ).bool()
+        trg_mask = trg_pad_mask & trg_sub_mask
+        return trg_mask
+
+    def encoder(self, src: Tensor) -> Tensor:
+        """Forward pass with the encoder of the transformer."""
+        return self.transformer.encoder(src)
+
+    def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
+        """Forward pass with the decoder of the transformer + classification head."""
+        return self.head(
+            self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+        )
+
+    def extract_image_features(self, src: Tensor) -> Tensor:
+        """Extracts image features with a backbone neural network.
+
+        It seem like the winning idea was to swap channels and width dimension and collapse
+        the height dimension. The transformer is learning like a baby with this implementation!!! :D
+        Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D
+
+        Args:
+            src (Tensor): Input tensor.
+
+        Returns:
+            Tensor: A input src to the transformer.
+
+        """
+        # If batch dimension is missing, it needs to be added.
+        if len(src.shape) < 4:
+            src = src[(None,) * (4 - len(src.shape))]
+
+        patches = rearrange(
+            src,
+            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
+            p1=self.patch_dim[0],
+            p2=self.patch_dim[1],
+        )
+
+        # From patches to encoded sequence.
+        x = self.patch_to_embedding(patches)
+        b, n, _ = x.shape
+        cls_tokens = repeat(self.cls_token, "() n d -> b n d", b=b)
+        x = torch.cat((cls_tokens, x), dim=1)
+        x += self.pos_embedding[:, : (n + 1)]
+        x = self.dropout(x)
+
+        return x
+
+    def target_embedding(self, trg: Tensor) -> Tuple[Tensor, Tensor]:
+        """Encodes target tensor with embedding and postion.
+
+        Args:
+            trg (Tensor): Target tensor.
+
+        Returns:
+            Tuple[Tensor, Tensor]: Encoded target tensor and target mask.
+
+        """
+        _, n = trg.shape
+        trg = self.character_embedding(trg.long())
+        trg += self.pos_embedding[:, :n]
+        return trg
+
+    def decode_image_features(self, h: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+        """Takes images features from the backbone and decodes them with the transformer."""
+        trg_mask = self._create_trg_mask(trg)
+        trg = self.target_embedding(trg)
+        out = self.transformer(h, trg, trg_mask=trg_mask)
+
+        logits = self.head(out)
+        return logits
+
+    def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+        """Forward pass with CNN transfomer."""
+        h = self.extract_image_features(x)
+        logits = self.decode_image_features(h, trg)
+        return logits
```

```diff
diff --git a/src/text_recognizer/networks/wide_resnet.py b/src/text_recognizer/networks/wide_resnet.py
index 28f3380..b767778 100644
--- a/src/text_recognizer/networks/wide_resnet.py
+++ b/src/text_recognizer/networks/wide_resnet.py
@@ -113,6 +113,7 @@ class WideResidualNetwork(nn.Module):
         dropout_rate: float = 0.0,
         num_layers: int = 3,
         block: Type[nn.Module] = WideBlock,
+        num_stages: Optional[List[int]] = None,
         activation: str = "relu",
         use_decoder: bool = True,
     ) -> None:
@@ -127,6 +128,7 @@ class WideResidualNetwork(nn.Module):
             dropout_rate (float): The dropout rate. Defaults to 0.0.
             num_layers (int): Number of layers of blocks. Defaults to 3.
             block (Type[nn.Module]): The default block is WideBlock. Defaults to WideBlock.
+            num_stages (List[int]): If given, will use these channel values. Defaults to None.
             activation (str): Name of the activation to use. Defaults to "relu".
             use_decoder (bool): If True, the network output character predictions, if False, the
                 network outputs a latent vector. Defaults to True.
@@ -149,9 +151,14 @@ class WideResidualNetwork(nn.Module):
         self.dropout_rate = dropout_rate
         self.activation = activation_function(activation)

-        self.num_stages = [self.in_planes] + [
-            self.in_planes * 2 ** n * self.width_factor for n in range(self.num_layers)
-        ]
+        if num_stages is None:
+            self.num_stages = [self.in_planes] + [
+                self.in_planes * 2 ** n * self.width_factor
+                for n in range(self.num_layers)
+            ]
+        else:
+            self.num_stages = [self.in_planes] + num_stages
+
         self.num_stages = list(zip(self.num_stages, self.num_stages[1:]))
         self.strides = [1] + [2] * (self.num_layers - 1)
```
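The patch extraction at the heart of the new `ViT` is the einops rearrange in `extract_image_features`: a `(B, C, H, W)` image is cut into `p1 x p2` patches and flattened into a token sequence. With a hypothetical 28-pixel-high line image:

```python
import torch
from einops import rearrange

src = torch.randn(1, 1, 28, 952)  # (B, C, H, W); sizes are illustrative
patches = rearrange(
    src, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=28, p2=28
)
print(patches.shape)  # torch.Size([1, 34, 784]) -> 34 patch tokens of dim 784
```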