From ff9a21d333f11a42e67c1963ed67de9c0fda87c9 Mon Sep 17 00:00:00 2001
From: aktersnurra
Date: Thu, 7 Jan 2021 20:10:54 +0100
Subject: Minor updates.

---
 src/text_recognizer/networks/__init__.py          |   6 +-
 src/text_recognizer/networks/cnn_transformer.py   |  47 +++++--
 src/text_recognizer/networks/metrics.py           |  25 ++--
 src/text_recognizer/networks/residual_network.py  |   4 +-
 .../networks/transformer/__init__.py              |   2 +-
 .../networks/transformer/transformer.py           |  26 +++-
 src/text_recognizer/networks/util.py              |   1 +
 src/text_recognizer/networks/vit.py               | 150 +++++++++++++++++++++
 src/text_recognizer/networks/wide_resnet.py       |  13 +-
 9 files changed, 236 insertions(+), 38 deletions(-)
 create mode 100644 src/text_recognizer/networks/vit.py

(limited to 'src/text_recognizer/networks')

diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index f958672..2b624bb 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -3,19 +3,18 @@ from .cnn_transformer import CNNTransformer
 from .crnn import ConvolutionalRecurrentNetwork
 from .ctc import greedy_decoder
 from .densenet import DenseNet
-from .fcn import FCN
 from .lenet import LeNet
-from .metrics import accuracy, accuracy_ignore_pad, cer, wer
+from .metrics import accuracy, cer, wer
 from .mlp import MLP
 from .residual_network import ResidualNetwork, ResidualNetworkEncoder
 from .transformer import Transformer
 from .unet import UNet
 from .util import sliding_window
+from .vit import ViT
 from .wide_resnet import WideResidualNetwork
 
 __all__ = [
     "accuracy",
-    "accuracy_ignore_pad",
     "cer",
     "CNNTransformer",
     "ConvolutionalRecurrentNetwork",
@@ -29,6 +28,7 @@ __all__ = [
     "sliding_window",
     "UNet",
     "Transformer",
+    "ViT",
     "wer",
     "WideResidualNetwork",
 ]
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
index b2b74b3..caa73e3 100644
--- a/src/text_recognizer/networks/cnn_transformer.py
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -1,12 +1,13 @@
 """A CNN-Transformer for image to text recognition."""
 from typing import Dict, Optional, Tuple
 
-from einops import rearrange
+from einops import rearrange, repeat
 import torch
 from torch import nn
 from torch import Tensor
 
 from text_recognizer.networks.transformer import PositionalEncoding, Transformer
+from text_recognizer.networks.util import activation_function
 from text_recognizer.networks.util import configure_backbone
 
 
@@ -24,15 +25,21 @@ class CNNTransformer(nn.Module):
         expansion_dim: int,
         dropout_rate: float,
         trg_pad_index: int,
+        max_len: int,
         backbone: str,
         backbone_args: Optional[Dict] = None,
         activation: str = "gelu",
     ) -> None:
         super().__init__()
         self.trg_pad_index = trg_pad_index
+        self.vocab_size = vocab_size
         self.backbone = configure_backbone(backbone, backbone_args)
-        self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
-        self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+        self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim)
+
+        self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
+        self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+
+        nn.init.normal_(self.character_embedding.weight, std=0.02)
 
         self.adaptive_pool = (
             nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None
@@ -48,7 +55,11 @@ class CNNTransformer(nn.Module):
             activation,
         )
 
-        self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
+        self.head = nn.Sequential(
+            # nn.Linear(hidden_dim, hidden_dim * 2),
+            # activation_function(activation),
+            nn.Linear(hidden_dim, vocab_size),
+        )
 
     def _create_trg_mask(self, trg: Tensor) -> Tensor:
         # Move this outside the transformer.
@@ -96,7 +107,21 @@ class CNNTransformer(nn.Module):
         else:
             src = rearrange(src, "b c h w -> b (w h) c")
 
-        src = self.position_encoding(src)
+        b, t, _ = src.shape
+
+        # Insert sos and eos token.
+        # sos_token = self.character_embedding(
+        #     torch.Tensor([self.vocab_size - 2]).long().to(src.device)
+        # )
+        # eos_token = self.character_embedding(
+        #     torch.Tensor([self.vocab_size - 1]).long().to(src.device)
+        # )
+
+        # sos_tokens = repeat(sos_token, "() h -> b h", b=b).unsqueeze(1)
+        # eos_tokens = repeat(eos_token, "() h -> b h", b=b).unsqueeze(1)
+        # src = torch.cat((sos_tokens, src, eos_tokens), dim=1)
+        # src = torch.cat((sos_tokens, src), dim=1)
+        src += self.src_position_embedding[:, :t]
 
         return src
 
@@ -111,20 +136,22 @@ class CNNTransformer(nn.Module):
 
         """
         trg = self.character_embedding(trg.long())
-        trg = self.position_encoding(trg)
+        trg = self.trg_position_encoding(trg)
         return trg
 
-    def decode_image_features(self, h: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+    def decode_image_features(
+        self, image_features: Tensor, trg: Optional[Tensor] = None
+    ) -> Tensor:
         """Takes image features from the backbone and decodes them with the transformer."""
         trg_mask = self._create_trg_mask(trg)
         trg = self.target_embedding(trg)
-        out = self.transformer(h, trg, trg_mask=trg_mask)
+        out = self.transformer(image_features, trg, trg_mask=trg_mask)
 
         logits = self.head(out)
         return logits
 
     def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
         """Forward pass with the CNN transformer."""
-        h = self.extract_image_features(x)
-        logits = self.decode_image_features(h, trg)
+        image_features = self.extract_image_features(x)
+        logits = self.decode_image_features(image_features, trg)
         return logits
diff --git a/src/text_recognizer/networks/metrics.py b/src/text_recognizer/networks/metrics.py
index af9adb5..ffad792 100644
--- a/src/text_recognizer/networks/metrics.py
+++ b/src/text_recognizer/networks/metrics.py
@@ -6,28 +6,13 @@ from torch import Tensor
 from text_recognizer.networks import greedy_decoder
 
 
-def accuracy_ignore_pad(
-    output: Tensor,
-    target: Tensor,
-    pad_index: int = 79,
-    eos_index: int = 81,
-    seq_len: int = 97,
-) -> float:
-    """Sets all predictions after eos to pad."""
-    start_indices = torch.nonzero(target == eos_index, as_tuple=False).squeeze(1)
-    end_indices = torch.arange(seq_len, target.shape[0] + 1, seq_len)
-    for start, stop in zip(start_indices, end_indices):
-        output[start + 1 : stop] = pad_index
-
-    return accuracy(output, target)
-
-
-def accuracy(outputs: Tensor, labels: Tensor,) -> float:
+def accuracy(outputs: Tensor, labels: Tensor, pad_index: int = 53) -> float:
     """Computes the accuracy.
 
     Args:
         outputs (Tensor): The output from the network.
         labels (Tensor): Ground truth labels.
+        pad_index (int): Padding index.
 
     Returns:
         float: The accuracy for the batch.
@@ -36,6 +21,12 @@ def accuracy(outputs: Tensor, labels: Tensor,) -> float:
     _, predicted = torch.max(outputs, dim=-1)
 
+    # Mask out the pad tokens
+    mask = labels != pad_index
+
+    predicted *= mask
+    labels *= mask
+
     acc = (predicted == labels).sum().float() / labels.shape[0]
     acc = acc.item()
     return acc
 
diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py
index e397224..c33f419 100644
--- a/src/text_recognizer/networks/residual_network.py
+++ b/src/text_recognizer/networks/residual_network.py
@@ -221,8 +221,8 @@ class ResidualNetworkEncoder(nn.Module):
             nn.Conv2d(
                 in_channels=in_channels,
                 out_channels=self.block_sizes[0],
-                kernel_size=3,
-                stride=1,
+                kernel_size=7,
+                stride=2,
                 padding=1,
                 bias=False,
             ),
diff --git a/src/text_recognizer/networks/transformer/__init__.py b/src/text_recognizer/networks/transformer/__init__.py
index 020a917..9febc88 100644
--- a/src/text_recognizer/networks/transformer/__init__.py
+++ b/src/text_recognizer/networks/transformer/__init__.py
@@ -1,3 +1,3 @@
 """Transformer modules."""
 from .positional_encoding import PositionalEncoding
-from .transformer import Decoder, Encoder, Transformer
+from .transformer import Decoder, Encoder, EncoderLayer, Transformer
diff --git a/src/text_recognizer/networks/transformer/transformer.py b/src/text_recognizer/networks/transformer/transformer.py
index c6e943e..dd180c4 100644
--- a/src/text_recognizer/networks/transformer/transformer.py
+++ b/src/text_recognizer/networks/transformer/transformer.py
@@ -6,11 +6,25 @@ import numpy as np
 import torch
 from torch import nn
 from torch import Tensor
+import torch.nn.functional as F
 
 from text_recognizer.networks.transformer.attention import MultiHeadAttention
 from text_recognizer.networks.util import activation_function
 
 
+class GEGLU(nn.Module):
+    """GLU activation for improving feedforward activations."""
+
+    def __init__(self, dim_in: int, dim_out: int) -> None:
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward propagation."""
+        x, gate = self.proj(x).chunk(2, dim=-1)
+        return x * F.gelu(gate)
+
+
 def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList:
     return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)])
 
@@ -36,9 +50,17 @@ class _ConvolutionalLayer(nn.Module):
         activation: str = "relu",
     ) -> None:
         super().__init__()
+
+        in_projection = (
+            nn.Sequential(
+                nn.Linear(hidden_dim, expansion_dim), activation_function(activation)
+            )
+            if activation != "glu"
+            else GEGLU(hidden_dim, expansion_dim)
+        )
+
         self.layer = nn.Sequential(
-            nn.Linear(in_features=hidden_dim, out_features=expansion_dim),
-            activation_function(activation),
+            in_projection,
             nn.Dropout(p=dropout_rate),
             nn.Linear(in_features=expansion_dim, out_features=hidden_dim),
         )
diff --git a/src/text_recognizer/networks/util.py b/src/text_recognizer/networks/util.py
index e2d7955..711a952 100644
--- a/src/text_recognizer/networks/util.py
+++ b/src/text_recognizer/networks/util.py
@@ -39,6 +39,7 @@ def activation_function(activation: str) -> Type[nn.Module]:
         [
             ["elu", nn.ELU(inplace=True)],
             ["gelu", nn.GELU()],
+            ["glu", nn.GLU()],
             ["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)],
             ["none", nn.Identity()],
             ["relu", nn.ReLU(inplace=True)],
diff --git a/src/text_recognizer/networks/vit.py b/src/text_recognizer/networks/vit.py
new file mode 100644
index 0000000..efb3701
--- /dev/null
+++ b/src/text_recognizer/networks/vit.py
@@ -0,0 +1,150 @@
+"""A Vision Transformer.
+
+Inspired by:
+https://openreview.net/pdf?id=YicbFdNTTy
+
+"""
+from typing import Optional, Tuple
+
+from einops import rearrange, repeat
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer import Transformer
+
+
+class ViT(nn.Module):
+    """Transformer for image to sequence prediction."""
+
+    def __init__(
+        self,
+        num_encoder_layers: int,
+        num_decoder_layers: int,
+        hidden_dim: int,
+        vocab_size: int,
+        num_heads: int,
+        expansion_dim: int,
+        patch_dim: Tuple[int, int],
+        image_size: Tuple[int, int],
+        dropout_rate: float,
+        trg_pad_index: int,
+        max_len: int,
+        activation: str = "gelu",
+    ) -> None:
+        super().__init__()
+
+        self.trg_pad_index = trg_pad_index
+        self.patch_dim = patch_dim
+        self.num_patches = image_size[-1] // self.patch_dim[1]
+
+        # Encoder
+        self.patch_to_embedding = nn.Linear(
+            self.patch_dim[0] * self.patch_dim[1], hidden_dim
+        )
+        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
+        self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
+        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
+        self.dropout = nn.Dropout(dropout_rate)
+        self._init()
+
+        self.transformer = Transformer(
+            num_encoder_layers,
+            num_decoder_layers,
+            hidden_dim,
+            num_heads,
+            expansion_dim,
+            dropout_rate,
+            activation,
+        )
+
+        self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
+
+    def _init(self) -> None:
+        nn.init.normal_(self.character_embedding.weight, std=0.02)
+        # nn.init.normal_(self.pos_embedding.weight, std=0.02)
+
+    def _create_trg_mask(self, trg: Tensor) -> Tensor:
+        # Move this outside the transformer.
+        trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
+        trg_len = trg.shape[1]
+        trg_sub_mask = torch.tril(
+            torch.ones((trg_len, trg_len), device=trg.device)
+        ).bool()
+        trg_mask = trg_pad_mask & trg_sub_mask
+        return trg_mask
+
+    def encoder(self, src: Tensor) -> Tensor:
+        """Forward pass with the encoder of the transformer."""
+        return self.transformer.encoder(src)
+
+    def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
+        """Forward pass with the decoder of the transformer + classification head."""
+        return self.head(
+            self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+        )
+
+    def extract_image_features(self, src: Tensor) -> Tensor:
+        """Splits the image into patches and embeds them as a sequence.
+
+        The image is cut into fixed-size patches, each patch is flattened and
+        linearly embedded, a learnable cls token is prepended, and a learned
+        positional embedding is added, following the ViT recipe.
+
+        Args:
+            src (Tensor): Input image tensor.
+
+        Returns:
+            Tensor: The embedded patch sequence for the transformer encoder.
+
+        """
+        # If batch dimension is missing, it needs to be added.
+        if len(src.shape) < 4:
+            src = src[(None,) * (4 - len(src.shape))]
+
+        patches = rearrange(
+            src,
+            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
+            p1=self.patch_dim[0],
+            p2=self.patch_dim[1],
+        )
+
+        # From patches to encoded sequence.
+        x = self.patch_to_embedding(patches)
+        b, n, _ = x.shape
+        cls_tokens = repeat(self.cls_token, "() n d -> b n d", b=b)
+        x = torch.cat((cls_tokens, x), dim=1)
+        x += self.pos_embedding[:, : (n + 1)]
+        x = self.dropout(x)
+
+        return x
+
+    def target_embedding(self, trg: Tensor) -> Tensor:
+        """Encodes the target tensor with character and position embeddings.
+
+        Args:
+            trg (Tensor): Target tensor.
+
+        Returns:
+            Tensor: The encoded target tensor.
+
+        """
+        _, n = trg.shape
+        trg = self.character_embedding(trg.long())
+        trg += self.pos_embedding[:, :n]
+        return trg
+
+    def decode_image_features(self, h: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+        """Decodes the embedded patch sequence with the transformer."""
+        trg_mask = self._create_trg_mask(trg)
+        trg = self.target_embedding(trg)
+        out = self.transformer(h, trg, trg_mask=trg_mask)
+
+        logits = self.head(out)
+        return logits
+
+    def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+        """Forward pass with the Vision Transformer."""
+        h = self.extract_image_features(x)
+        logits = self.decode_image_features(h, trg)
+        return logits
diff --git a/src/text_recognizer/networks/wide_resnet.py b/src/text_recognizer/networks/wide_resnet.py
index 28f3380..b767778 100644
--- a/src/text_recognizer/networks/wide_resnet.py
+++ b/src/text_recognizer/networks/wide_resnet.py
@@ -113,6 +113,7 @@ class WideResidualNetwork(nn.Module):
         dropout_rate: float = 0.0,
         num_layers: int = 3,
         block: Type[nn.Module] = WideBlock,
+        num_stages: Optional[List[int]] = None,
         activation: str = "relu",
         use_decoder: bool = True,
     ) -> None:
@@ -127,6 +128,7 @@ class WideResidualNetwork(nn.Module):
             dropout_rate (float): The dropout rate. Defaults to 0.0.
             num_layers (int): Number of layers of blocks. Defaults to 3.
             block (Type[nn.Module]): The default block is WideBlock. Defaults to WideBlock.
+            num_stages (List[int]): If given, will use these channel values. Defaults to None.
             activation (str): Name of the activation to use. Defaults to "relu".
             use_decoder (bool): If True, the network outputs character predictions,
                 if False, the network outputs a latent vector. Defaults to True.
@@ -149,9 +151,14 @@ class WideResidualNetwork(nn.Module):
         self.dropout_rate = dropout_rate
         self.activation = activation_function(activation)
 
-        self.num_stages = [self.in_planes] + [
-            self.in_planes * 2 ** n * self.width_factor for n in range(self.num_layers)
-        ]
+        if num_stages is None:
+            self.num_stages = [self.in_planes] + [
+                self.in_planes * 2 ** n * self.width_factor
+                for n in range(self.num_layers)
+            ]
+        else:
+            self.num_stages = [self.in_planes] + num_stages
+
         self.num_stages = list(zip(self.num_stages, self.num_stages[1:]))
         self.strides = [1] + [2] * (self.num_layers - 1)
-- 
cgit v1.2.3-70-g09d2
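
For reference, the GEGLU block added to transformer.py above is self-contained, so it can be exercised in isolation. Below is a minimal sketch in plain PyTorch; the class body is copied from the patch, while the tensor sizes (hidden_dim=256, expansion_dim=1024, sequence length 97) are illustrative assumptions, not values taken from the repository's configs.

    import torch
    from torch import Tensor, nn
    import torch.nn.functional as F


    class GEGLU(nn.Module):
        """Copied from the patch: one projection produces both the value and its gate."""

        def __init__(self, dim_in: int, dim_out: int) -> None:
            super().__init__()
            self.proj = nn.Linear(dim_in, dim_out * 2)

        def forward(self, x: Tensor) -> Tensor:
            # Split the doubled projection into value and gate, then gate with GELU.
            x, gate = self.proj(x).chunk(2, dim=-1)
            return x * F.gelu(gate)


    layer = GEGLU(dim_in=256, dim_out=1024)  # assumed sizes for illustration only
    out = layer(torch.randn(2, 97, 256))     # (batch, sequence, hidden_dim)
    print(out.shape)                         # torch.Size([2, 97, 1024])

The output shape matches the Linear-plus-activation pair it replaces in _ConvolutionalLayer, which is what lets it be swapped in behind the activation != "glu" branch. Note that the feedforward special-cases "glu" to GEGLU before activation_function is ever consulted, so the nn.GLU entry registered in util.py only matters for other callers of activation_function.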
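
The pad masking added to accuracy() in metrics.py can likewise be traced with a toy batch. One subtlety worth seeing explicitly: the mask zeroes both predictions and labels at pad positions, so those positions always compare equal and still count toward the numerator. The toy logits and pad_index=0 below are made up for illustration; the patch itself defaults to pad_index=53.

    import torch

    pad_index = 0  # illustrative; the patch defaults to 53

    # Four positions over a five-symbol vocabulary; position 2 is padding.
    outputs = torch.tensor(
        [
            [0.1, 2.0, 0.0, 0.0, 0.0],  # argmax -> 1, label 1: correct
            [0.1, 0.0, 3.0, 0.0, 0.0],  # argmax -> 2, label 2: correct
            [5.0, 0.0, 0.0, 0.0, 0.0],  # argmax -> 0, pad position
            [0.0, 0.0, 0.0, 4.0, 0.0],  # argmax -> 3, label 4: wrong
        ]
    )
    labels = torch.tensor([1, 2, pad_index, 4])

    # The same steps as the patched accuracy().
    _, predicted = torch.max(outputs, dim=-1)
    mask = labels != pad_index
    predicted *= mask
    labels *= mask

    acc = (predicted == labels).sum().float() / labels.shape[0]
    print(acc.item())  # 0.75: two true matches, one pad match, one miss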