path: root/src/text_recognizer/networks
Diffstat (limited to 'src/text_recognizer/networks')
-rw-r--r--  src/text_recognizer/networks/__init__.py                |   6
-rw-r--r--  src/text_recognizer/networks/cnn_transformer.py         |  47
-rw-r--r--  src/text_recognizer/networks/metrics.py                 |  25
-rw-r--r--  src/text_recognizer/networks/residual_network.py        |   4
-rw-r--r--  src/text_recognizer/networks/transformer/__init__.py    |   2
-rw-r--r--  src/text_recognizer/networks/transformer/transformer.py |  26
-rw-r--r--  src/text_recognizer/networks/util.py                    |   1
-rw-r--r--  src/text_recognizer/networks/vit.py                     | 150
-rw-r--r--  src/text_recognizer/networks/wide_resnet.py             |  13
9 files changed, 236 insertions, 38 deletions
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index f958672..2b624bb 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -3,19 +3,18 @@ from .cnn_transformer import CNNTransformer
from .crnn import ConvolutionalRecurrentNetwork
from .ctc import greedy_decoder
from .densenet import DenseNet
-from .fcn import FCN
from .lenet import LeNet
-from .metrics import accuracy, accuracy_ignore_pad, cer, wer
+from .metrics import accuracy, cer, wer
from .mlp import MLP
from .residual_network import ResidualNetwork, ResidualNetworkEncoder
from .transformer import Transformer
from .unet import UNet
from .util import sliding_window
+from .vit import ViT
from .wide_resnet import WideResidualNetwork
__all__ = [
"accuracy",
- "accuracy_ignore_pad",
"cer",
"CNNTransformer",
"ConvolutionalRecurrentNetwork",
@@ -29,6 +28,7 @@ __all__ = [
"sliding_window",
"UNet",
"Transformer",
+ "ViT",
"wer",
"WideResidualNetwork",
]
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
index b2b74b3..caa73e3 100644
--- a/src/text_recognizer/networks/cnn_transformer.py
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -1,12 +1,13 @@
"""A CNN-Transformer for image to text recognition."""
from typing import Dict, Optional, Tuple
-from einops import rearrange
+from einops import rearrange, repeat
import torch
from torch import nn
from torch import Tensor
from text_recognizer.networks.transformer import PositionalEncoding, Transformer
+from text_recognizer.networks.util import activation_function
from text_recognizer.networks.util import configure_backbone
@@ -24,15 +25,21 @@ class CNNTransformer(nn.Module):
expansion_dim: int,
dropout_rate: float,
trg_pad_index: int,
+ max_len: int,
backbone: str,
backbone_args: Optional[Dict] = None,
activation: str = "gelu",
) -> None:
super().__init__()
self.trg_pad_index = trg_pad_index
+ self.vocab_size = vocab_size
self.backbone = configure_backbone(backbone, backbone_args)
- self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
- self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+ self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim)
+
+ self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
+ self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+
+ nn.init.normal_(self.character_embedding.weight, std=0.02)
self.adaptive_pool = (
nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None
@@ -48,7 +55,11 @@ class CNNTransformer(nn.Module):
activation,
)
- self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
+ self.head = nn.Sequential(
+ # nn.Linear(hidden_dim, hidden_dim * 2),
+ # activation_function(activation),
+ nn.Linear(hidden_dim, vocab_size),
+ )
def _create_trg_mask(self, trg: Tensor) -> Tensor:
# Move this outside the transformer.
@@ -96,7 +107,21 @@ class CNNTransformer(nn.Module):
else:
src = rearrange(src, "b c h w -> b (w h) c")
- src = self.position_encoding(src)
+ b, t, _ = src.shape
+
+ # Insert sos and eos token.
+ # sos_token = self.character_embedding(
+ # torch.Tensor([self.vocab_size - 2]).long().to(src.device)
+ # )
+ # eos_token = self.character_embedding(
+ # torch.Tensor([self.vocab_size - 1]).long().to(src.device)
+ # )
+
+ # sos_tokens = repeat(sos_token, "() h -> b h", b=b).unsqueeze(1)
+ # eos_tokens = repeat(eos_token, "() h -> b h", b=b).unsqueeze(1)
+ # src = torch.cat((sos_tokens, src, eos_tokens), dim=1)
+ # src = torch.cat((sos_tokens, src), dim=1)
+ src += self.src_position_embedding[:, :t]
return src
@@ -111,20 +136,22 @@ class CNNTransformer(nn.Module):
"""
trg = self.character_embedding(trg.long())
- trg = self.position_encoding(trg)
+ trg = self.trg_position_encoding(trg)
return trg
- def decode_image_features(self, h: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+ def decode_image_features(
+ self, image_features: Tensor, trg: Optional[Tensor] = None
+ ) -> Tensor:
"""Takes images features from the backbone and decodes them with the transformer."""
trg_mask = self._create_trg_mask(trg)
trg = self.target_embedding(trg)
- out = self.transformer(h, trg, trg_mask=trg_mask)
+ out = self.transformer(image_features, trg, trg_mask=trg_mask)
logits = self.head(out)
return logits
def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
"""Forward pass with CNN transfomer."""
- h = self.extract_image_features(x)
- logits = self.decode_image_features(h, trg)
+ image_features = self.extract_image_features(x)
+ logits = self.decode_image_features(image_features, trg)
return logits
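
A note on the change above: the src positional encoding is now a learned
parameter sliced to the actual sequence length, while the sinusoidal
PositionalEncoding is kept for the target only. A minimal sketch of the
slicing and broadcasting, with illustrative sizes (hidden_dim=256,
max_len=512, batch and sequence sizes not taken from the diff):

    import torch
    from torch import nn

    hidden_dim, max_len = 256, 512
    src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))

    src = torch.randn(8, 97, hidden_dim)       # (batch, sequence, hidden)
    b, t, _ = src.shape

    # Slice to the actual length t; dim 0 broadcasts over the batch.
    src = src + src_position_embedding[:, :t]
    print(src.shape)                           # torch.Size([8, 97, 256])
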
diff --git a/src/text_recognizer/networks/metrics.py b/src/text_recognizer/networks/metrics.py
index af9adb5..ffad792 100644
--- a/src/text_recognizer/networks/metrics.py
+++ b/src/text_recognizer/networks/metrics.py
@@ -6,28 +6,13 @@ from torch import Tensor
from text_recognizer.networks import greedy_decoder
-def accuracy_ignore_pad(
- output: Tensor,
- target: Tensor,
- pad_index: int = 79,
- eos_index: int = 81,
- seq_len: int = 97,
-) -> float:
- """Sets all predictions after eos to pad."""
- start_indices = torch.nonzero(target == eos_index, as_tuple=False).squeeze(1)
- end_indices = torch.arange(seq_len, target.shape[0] + 1, seq_len)
- for start, stop in zip(start_indices, end_indices):
- output[start + 1 : stop] = pad_index
-
- return accuracy(output, target)
-
-
-def accuracy(outputs: Tensor, labels: Tensor,) -> float:
+def accuracy(outputs: Tensor, labels: Tensor, pad_index: int = 53) -> float:
"""Computes the accuracy.
Args:
outputs (Tensor): The output from the network.
labels (Tensor): Ground truth labels.
+ pad_index (int): Padding index.
Returns:
float: The accuracy for the batch.
@@ -36,6 +21,12 @@ def accuracy(outputs: Tensor, labels: Tensor,) -> float:
_, predicted = torch.max(outputs, dim=-1)
+ # Mask out the pad tokens
+ mask = labels != pad_index
+
+ predicted *= mask
+ labels *= mask
+
acc = (predicted == labels).sum().float() / labels.shape[0]
acc = acc.item()
return acc
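
A toy run of the new pad masking in accuracy, assuming flattened 1-D targets
and the default pad_index of 53 (both chosen here for illustration):

    import torch

    pad_index = 53
    labels = torch.tensor([10, 11, 53, 53, 12, 53])   # flattened targets
    outputs = torch.randn(6, 60)                      # (tokens, vocab_size) logits

    _, predicted = torch.max(outputs, dim=-1)
    mask = labels != pad_index         # False at pad positions
    predicted = predicted * mask       # pads become 0 in both tensors,
    labels = labels * mask             # so pad positions always compare equal
    acc = (predicted == labels).sum().float() / labels.shape[0]

Note that the zeroed pad positions match by construction, so they still count
toward the score; padding inflates the accuracy rather than being excluded
from the denominator.
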
diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py
index e397224..c33f419 100644
--- a/src/text_recognizer/networks/residual_network.py
+++ b/src/text_recognizer/networks/residual_network.py
@@ -221,8 +221,8 @@ class ResidualNetworkEncoder(nn.Module):
nn.Conv2d(
in_channels=in_channels,
out_channels=self.block_sizes[0],
- kernel_size=3,
- stride=1,
+ kernel_size=7,
+ stride=2,
padding=1,
bias=False,
),
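
The stem convolution now downsamples. With the usual output-size formula,
floor((H + 2 * padding - kernel) / stride) + 1, a 28-pixel input gives
floor((28 + 2 - 7) / 2) + 1 = 12. A quick check with illustrative channel
counts:

    import torch
    from torch import nn

    conv = nn.Conv2d(1, 32, kernel_size=7, stride=2, padding=1, bias=False)
    x = torch.randn(1, 1, 28, 28)
    print(conv(x).shape)   # torch.Size([1, 32, 12, 12])
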
diff --git a/src/text_recognizer/networks/transformer/__init__.py b/src/text_recognizer/networks/transformer/__init__.py
index 020a917..9febc88 100644
--- a/src/text_recognizer/networks/transformer/__init__.py
+++ b/src/text_recognizer/networks/transformer/__init__.py
@@ -1,3 +1,3 @@
"""Transformer modules."""
from .positional_encoding import PositionalEncoding
-from .transformer import Decoder, Encoder, Transformer
+from .transformer import Decoder, Encoder, EncoderLayer, Transformer
diff --git a/src/text_recognizer/networks/transformer/transformer.py b/src/text_recognizer/networks/transformer/transformer.py
index c6e943e..dd180c4 100644
--- a/src/text_recognizer/networks/transformer/transformer.py
+++ b/src/text_recognizer/networks/transformer/transformer.py
@@ -6,11 +6,25 @@ import numpy as np
import torch
from torch import nn
from torch import Tensor
+import torch.nn.functional as F
from text_recognizer.networks.transformer.attention import MultiHeadAttention
from text_recognizer.networks.util import activation_function
+class GEGLU(nn.Module):
+ """GLU activation for improving feedforward activations."""
+
+ def __init__(self, dim_in: int, dim_out: int) -> None:
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward propagation."""
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList:
return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)])
@@ -36,9 +50,17 @@ class _ConvolutionalLayer(nn.Module):
activation: str = "relu",
) -> None:
super().__init__()
+
+ in_projection = (
+ nn.Sequential(
+ nn.Linear(hidden_dim, expansion_dim), activation_function(activation)
+ )
+ if activation != "glu"
+ else GEGLU(hidden_dim, expansion_dim)
+ )
+
self.layer = nn.Sequential(
- nn.Linear(in_features=hidden_dim, out_features=expansion_dim),
- activation_function(activation),
+ in_projection,
nn.Dropout(p=dropout_rate),
nn.Linear(in_features=expansion_dim, out_features=hidden_dim),
)
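
GEGLU above follows the GEGLU variant from "GLU Variants Improve Transformer"
(Shazeer, 2020): one projection to twice the output width, split into a value
half and a GELU-gated half. The special case on activation != "glu" is needed
because plain nn.GLU() halves its input's last dimension instead of taking an
explicit output width. A standalone sketch with illustrative sizes:

    import torch
    import torch.nn.functional as F
    from torch import nn

    proj = nn.Linear(256, 512 * 2)            # dim_in=256, dim_out=512
    x = torch.randn(8, 97, 256)

    value, gate = proj(x).chunk(2, dim=-1)    # two (8, 97, 512) halves
    out = value * F.gelu(gate)                # elementwise gating
    print(out.shape)                          # torch.Size([8, 97, 512])
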
diff --git a/src/text_recognizer/networks/util.py b/src/text_recognizer/networks/util.py
index e2d7955..711a952 100644
--- a/src/text_recognizer/networks/util.py
+++ b/src/text_recognizer/networks/util.py
@@ -39,6 +39,7 @@ def activation_function(activation: str) -> Type[nn.Module]:
[
["elu", nn.ELU(inplace=True)],
["gelu", nn.GELU()],
+ ["glu", nn.GLU()],
["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)],
["none", nn.Identity()],
["relu", nn.ReLU(inplace=True)],
diff --git a/src/text_recognizer/networks/vit.py b/src/text_recognizer/networks/vit.py
new file mode 100644
index 0000000..efb3701
--- /dev/null
+++ b/src/text_recognizer/networks/vit.py
@@ -0,0 +1,150 @@
+"""A Vision Transformer.
+
+Inspired by:
+https://openreview.net/pdf?id=YicbFdNTTy
+
+"""
+from typing import Optional, Tuple
+
+from einops import rearrange, repeat
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer import Transformer
+
+
+class ViT(nn.Module):
+ """Transfomer for image to sequence prediction."""
+
+ def __init__(
+ self,
+ num_encoder_layers: int,
+ num_decoder_layers: int,
+ hidden_dim: int,
+ vocab_size: int,
+ num_heads: int,
+ expansion_dim: int,
+ patch_dim: Tuple[int, int],
+ image_size: Tuple[int, int],
+ dropout_rate: float,
+ trg_pad_index: int,
+ max_len: int,
+ activation: str = "gelu",
+ ) -> None:
+ super().__init__()
+
+ self.trg_pad_index = trg_pad_index
+ self.patch_dim = patch_dim
+ self.num_patches = image_size[-1] // self.patch_dim[1]
+
+ # Encoder
+ self.patch_to_embedding = nn.Linear(
+ self.patch_dim[0] * self.patch_dim[1], hidden_dim
+ )
+ self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
+ self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
+ self.pos_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
+ self.dropout = nn.Dropout(dropout_rate)
+ self._init()
+
+ self.transformer = Transformer(
+ num_encoder_layers,
+ num_decoder_layers,
+ hidden_dim,
+ num_heads,
+ expansion_dim,
+ dropout_rate,
+ activation,
+ )
+
+ self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
+
+ def _init(self) -> None:
+ nn.init.normal_(self.character_embedding.weight, std=0.02)
+ # nn.init.normal_(self.pos_embedding.weight, std=0.02)
+
+ def _create_trg_mask(self, trg: Tensor) -> Tensor:
+ # Move this outside the transformer.
+ trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
+ trg_len = trg.shape[1]
+ trg_sub_mask = torch.tril(
+ torch.ones((trg_len, trg_len), device=trg.device)
+ ).bool()
+ trg_mask = trg_pad_mask & trg_sub_mask
+ return trg_mask
+
+ def encoder(self, src: Tensor) -> Tensor:
+ """Forward pass with the encoder of the transformer."""
+ return self.transformer.encoder(src)
+
+ def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
+ """Forward pass with the decoder of the transformer + classification head."""
+ return self.head(
+ self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+ )
+
+ def extract_image_features(self, src: Tensor) -> Tensor:
+ """Extracts image features with a backbone neural network.
+
+ It seem like the winning idea was to swap channels and width dimension and collapse
+ the height dimension. The transformer is learning like a baby with this implementation!!! :D
+ Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D
+
+ Args:
+ src (Tensor): Input tensor.
+
+ Returns:
+ Tensor: The src input to the transformer.
+
+ """
+ # If batch dimension is missing, it needs to be added.
+ if len(src.shape) < 4:
+ src = src[(None,) * (4 - len(src.shape))]
+
+ patches = rearrange(
+ src,
+ "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
+ p1=self.patch_dim[0],
+ p2=self.patch_dim[1],
+ )
+
+ # From patches to encoded sequence.
+ x = self.patch_to_embedding(patches)
+ b, n, _ = x.shape
+ cls_tokens = repeat(self.cls_token, "() n d -> b n d", b=b)
+ x = torch.cat((cls_tokens, x), dim=1)
+ x += self.pos_embedding[:, : (n + 1)]
+ x = self.dropout(x)
+
+ return x
+
+ def target_embedding(self, trg: Tensor) -> Tensor:
+ """Encodes target tensor with embedding and position.
+
+ Args:
+ trg (Tensor): Target tensor.
+
+ Returns:
+ Tensor: Encoded target tensor.
+
+ """
+ _, n = trg.shape
+ trg = self.character_embedding(trg.long())
+ trg += self.pos_embedding[:, :n]
+ return trg
+
+ def decode_image_features(self, h: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+ """Takes images features from the backbone and decodes them with the transformer."""
+ trg_mask = self._create_trg_mask(trg)
+ trg = self.target_embedding(trg)
+ out = self.transformer(h, trg, trg_mask=trg_mask)
+
+ logits = self.head(out)
+ return logits
+
+ def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+ """Forward pass with CNN transfomer."""
+ h = self.extract_image_features(x)
+ logits = self.decode_image_features(h, trg)
+ return logits
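
A shape walk-through of the patch embedding above, with illustrative values
(a single-channel 28x952 line image, 28x28 patches, hidden_dim=256). Note
that patch_to_embedding sizes its input as p1 * p2, which assumes
single-channel images:

    import torch
    from einops import rearrange, repeat
    from torch import nn

    src = torch.randn(8, 1, 28, 952)                 # (b, c, h, w)
    patches = rearrange(
        src, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=28, p2=28
    )                                                # (8, 34, 784)

    patch_to_embedding = nn.Linear(28 * 28, 256)
    x = patch_to_embedding(patches)                  # (8, 34, 256)

    cls_token = nn.Parameter(torch.randn(1, 1, 256))
    cls_tokens = repeat(cls_token, "() n d -> b n d", b=8)
    x = torch.cat((cls_tokens, x), dim=1)            # (8, 35, 256)
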
diff --git a/src/text_recognizer/networks/wide_resnet.py b/src/text_recognizer/networks/wide_resnet.py
index 28f3380..b767778 100644
--- a/src/text_recognizer/networks/wide_resnet.py
+++ b/src/text_recognizer/networks/wide_resnet.py
@@ -113,6 +113,7 @@ class WideResidualNetwork(nn.Module):
dropout_rate: float = 0.0,
num_layers: int = 3,
block: Type[nn.Module] = WideBlock,
+ num_stages: Optional[List[int]] = None,
activation: str = "relu",
use_decoder: bool = True,
) -> None:
@@ -127,6 +128,7 @@ class WideResidualNetwork(nn.Module):
dropout_rate (float): The dropout rate. Defaults to 0.0.
num_layers (int): Number of layers of blocks. Defaults to 3.
block (Type[nn.Module]): The default block is WideBlock. Defaults to WideBlock.
+ num_stages (List[int]): If given, will use these channel values. Defaults to None.
activation (str): Name of the activation to use. Defaults to "relu".
use_decoder (bool): If True, the network output character predictions, if False, the network outputs a
latent vector. Defaults to True.
@@ -149,9 +151,14 @@ class WideResidualNetwork(nn.Module):
self.dropout_rate = dropout_rate
self.activation = activation_function(activation)
- self.num_stages = [self.in_planes] + [
- self.in_planes * 2 ** n * self.width_factor for n in range(self.num_layers)
- ]
+ if num_stages is None:
+ self.num_stages = [self.in_planes] + [
+ self.in_planes * 2 ** n * self.width_factor
+ for n in range(self.num_layers)
+ ]
+ else:
+ self.num_stages = [self.in_planes] + num_stages
+
self.num_stages = list(zip(self.num_stages, self.num_stages[1:]))
self.strides = [1] + [2] * (self.num_layers - 1)
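
The new num_stages argument overrides the default doubling schedule. With
illustrative values in_planes=16, width_factor=4 and num_layers=3 (not taken
from the diff), the two paths produce:

    in_planes, width_factor, num_layers = 16, 4, 3   # illustrative values

    # Default: widths double per stage, scaled by the width factor.
    default = [in_planes] + [
        in_planes * 2 ** n * width_factor for n in range(num_layers)
    ]
    print(default)                          # [16, 64, 128, 256]
    print(list(zip(default, default[1:])))  # [(16, 64), (64, 128), (128, 256)]

    # Custom: num_stages=[32, 64, 96] replaces the computed widths.
    custom = [in_planes] + [32, 64, 96]
    print(list(zip(custom, custom[1:])))    # [(16, 32), (32, 64), (64, 96)]
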