Diffstat (limited to 'src/text_recognizer')
-rw-r--r--  src/text_recognizer/datasets/dataset.py                    4
-rw-r--r--  src/text_recognizer/datasets/emnist_lines_dataset.py       3
-rw-r--r--  src/text_recognizer/datasets/iam_lines_dataset.py          2
-rw-r--r--  src/text_recognizer/datasets/transforms.py                 9
-rw-r--r--  src/text_recognizer/datasets/util.py                      15
-rw-r--r--  src/text_recognizer/models/__init__.py                     2
-rw-r--r--  src/text_recognizer/models/ctc_transformer_model.py      120
-rw-r--r--  src/text_recognizer/models/transformer_model.py            4
-rw-r--r--  src/text_recognizer/networks/__init__.py                   6
-rw-r--r--  src/text_recognizer/networks/cnn_transformer.py           47
-rw-r--r--  src/text_recognizer/networks/metrics.py                   25
-rw-r--r--  src/text_recognizer/networks/residual_network.py           4
-rw-r--r--  src/text_recognizer/networks/transformer/__init__.py       2
-rw-r--r--  src/text_recognizer/networks/transformer/transformer.py   26
-rw-r--r--  src/text_recognizer/networks/util.py                       1
-rw-r--r--  src/text_recognizer/networks/vit.py                      150
-rw-r--r--  src/text_recognizer/networks/wide_resnet.py               13
17 files changed, 388 insertions, 45 deletions
diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py
index 95063bc..e794605 100644
--- a/src/text_recognizer/datasets/dataset.py
+++ b/src/text_recognizer/datasets/dataset.py
@@ -22,6 +22,7 @@ class Dataset(data.Dataset):
init_token: Optional[str] = None,
pad_token: Optional[str] = None,
eos_token: Optional[str] = None,
+ lower: bool = False,
) -> None:
"""Initialization of Dataset class.
@@ -33,6 +34,7 @@ class Dataset(data.Dataset):
init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
pad_token (Optional[str]): String representing the pad token. Defaults to None.
eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
+ lower (bool): Only use lower case letters. Defaults to False.
Raises:
ValueError: If subsample_fraction is not None and outside the range (0, 1).
@@ -47,7 +49,7 @@ class Dataset(data.Dataset):
self.subsample_fraction = subsample_fraction
self._mapper = EmnistMapper(
- init_token=init_token, eos_token=eos_token, pad_token=pad_token
+ init_token=init_token, eos_token=eos_token, pad_token=pad_token, lower=lower
)
self._input_shape = self._mapper.input_shape
self._output_shape = self._mapper._num_classes
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index eddf341..1992446 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -44,6 +44,7 @@ class EmnistLinesDataset(Dataset):
init_token: Optional[str] = None,
pad_token: Optional[str] = None,
eos_token: Optional[str] = None,
+ lower: bool = False,
) -> None:
"""Set attributes and loads the dataset.
@@ -60,6 +61,7 @@ class EmnistLinesDataset(Dataset):
init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
pad_token (Optional[str]): String representing the pad token. Defaults to None.
eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
+ lower (bool): If True, convert uppercase letters to lowercase. Otherwise, use both upper and lowercase.
"""
self.pad_token = "_" if pad_token is None else pad_token
@@ -72,6 +74,7 @@ class EmnistLinesDataset(Dataset):
init_token=init_token,
pad_token=self.pad_token,
eos_token=eos_token,
+ lower=lower,
)
# Extract dataset information.
diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/src/text_recognizer/datasets/iam_lines_dataset.py
index 5ae142c..1cb84bd 100644
--- a/src/text_recognizer/datasets/iam_lines_dataset.py
+++ b/src/text_recognizer/datasets/iam_lines_dataset.py
@@ -35,6 +35,7 @@ class IamLinesDataset(Dataset):
init_token: Optional[str] = None,
pad_token: Optional[str] = None,
eos_token: Optional[str] = None,
+ lower: bool = False,
) -> None:
self.pad_token = "_" if pad_token is None else pad_token
@@ -46,6 +47,7 @@ class IamLinesDataset(Dataset):
init_token=init_token,
pad_token=pad_token,
eos_token=eos_token,
+ lower=lower,
)
@property
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
index 016ec80..8956b01 100644
--- a/src/text_recognizer/datasets/transforms.py
+++ b/src/text_recognizer/datasets/transforms.py
@@ -93,3 +93,12 @@ class Squeeze:
def __call__(self, x: Tensor) -> Tensor:
"""Removes first dim."""
return x.squeeze(0)
+
+
+class ToLower:
+ """Converts target to lower case."""
+
+ def __call__(self, target: Tensor) -> Tensor:
+ """Corrects index value in target tensor."""
+ device = target.device
+ return torch.stack([x - 26 if x > 35 else x for x in target]).to(device)
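A minimal sketch of how ToLower behaves, assuming the EMNIST byclass index layout (0-9 digits, 10-35 uppercase, 36-61 lowercase); the concrete indices are for illustration only:

    import torch

    from text_recognizer.datasets.transforms import ToLower

    target = torch.tensor([1, 12, 40])  # "1", "C", "e" under the byclass layout
    print(ToLower()(target))            # tensor([ 1, 12, 14]) -> "1", "c", "e" in the lowered mapping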
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index bf5e772..da87756 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -1,17 +1,14 @@
"""Util functions for datasets."""
import hashlib
-import importlib
import json
import os
from pathlib import Path
import string
-from typing import Callable, Dict, List, Optional, Type, Union
-from urllib.request import urlopen, urlretrieve
+from typing import Dict, List, Optional, Union
+from urllib.request import urlretrieve
-import cv2
from loguru import logger
import numpy as np
-from PIL import Image
import torch
from torch import Tensor
from torchvision.datasets import EMNIST
@@ -50,11 +47,13 @@ class EmnistMapper:
pad_token: str,
init_token: Optional[str] = None,
eos_token: Optional[str] = None,
+ lower: bool = False,
) -> None:
"""Loads the emnist essentials file with the mapping and input shape."""
self.init_token = init_token
self.pad_token = pad_token
self.eos_token = eos_token
+ self.lower = lower
self.essentials = self._load_emnist_essentials()
# Load dataset information.
@@ -120,6 +119,12 @@ class EmnistMapper:
def _augment_emnist_mapping(self) -> None:
"""Augment the mapping with extra symbols."""
# Extra symbols in IAM dataset
+ if self.lower:
+ self._mapping = {
+ k: str(v)
+ for k, v in enumerate(list(range(10)) + list(string.ascii_lowercase))
+ }
+
extra_symbols = [
" ",
"!",
diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py
index a645cec..eb5dbce 100644
--- a/src/text_recognizer/models/__init__.py
+++ b/src/text_recognizer/models/__init__.py
@@ -2,12 +2,14 @@
from .base import Model
from .character_model import CharacterModel
from .crnn_model import CRNNModel
+from .ctc_transformer_model import CTCTransformerModel
from .segmentation_model import SegmentationModel
from .transformer_model import TransformerModel
__all__ = [
"CharacterModel",
"CRNNModel",
+ "CTCTransformerModel",
"Model",
"SegmentationModel",
"TransformerModel",
diff --git a/src/text_recognizer/models/ctc_transformer_model.py b/src/text_recognizer/models/ctc_transformer_model.py
new file mode 100644
index 0000000..25925f2
--- /dev/null
+++ b/src/text_recognizer/models/ctc_transformer_model.py
@@ -0,0 +1,120 @@
+"""Defines the CTC Transformer Model class."""
+from typing import Callable, Dict, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets import EmnistMapper
+from text_recognizer.models.base import Model
+from text_recognizer.networks import greedy_decoder
+
+
+class CTCTransformerModel(Model):
+ """Model for predicting a sequence of characters from an image of a text line."""
+
+ def __init__(
+ self,
+ network_fn: Type[nn.Module],
+ dataset: Type[Dataset],
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ super().__init__(
+ network_fn,
+ dataset,
+ network_args,
+ dataset_args,
+ metrics,
+ criterion,
+ criterion_args,
+ optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
+ swa_args,
+ device,
+ )
+ self.pad_token = dataset_args["args"]["pad_token"]
+ self.lower = dataset_args["args"]["lower"]
+
+ if self._mapper is None:
+ self._mapper = EmnistMapper(pad_token=self.pad_token, lower=self.lower,)
+
+ self.tensor_transform = ToTensor()
+
+ def criterion(self, output: Tensor, targets: Tensor) -> Tensor:
+ """Computes the CTC loss.
+
+ Args:
+ output (Tensor): Model predictions.
+ targets (Tensor): Correct output sequence.
+
+ Returns:
+ Tensor: The CTC loss.
+
+ """
+ # Input lengths: a [B]-shaped tensor where every entry is the output sequence length T.
+ input_lengths = torch.full(
+ size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long,
+ )
+
+ # Configure target tensors for ctc loss.
+ targets_ = Tensor([]).to(self.device)
+ target_lengths = []
+ for t in targets:
+ # Remove padding symbol as it acts as the blank symbol.
+ t = t[t < 53]
+ targets_ = torch.cat([targets_, t])
+ target_lengths.append(len(t))
+
+ targets = targets_.type(dtype=torch.long)
+ target_lengths = (
+ torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device)
+ )
+
+ return self._criterion(output, targets, input_lengths, target_lengths)
+
+ @torch.no_grad()
+ def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
+ """Predict on a single input."""
+ self.eval()
+
+ if image.dtype == np.uint8:
+ # Converts an image with range [0, 255] to a PyTorch tensor with range [0, 1].
+ image = self.tensor_transform(image)
+
+ # Rescale image between 0 and 1.
+ if image.dtype == torch.uint8:
+ # If the image is an unscaled tensor.
+ image = image.type("torch.FloatTensor") / 255
+
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+ log_probs = self.forward(image)
+
+ raw_pred, _ = greedy_decoder(
+ predictions=log_probs,
+ character_mapper=self.mapper,
+ blank_label=53,
+ collapse_repeated=True,
+ )
+
+ log_probs, _ = log_probs.max(dim=2)
+
+ predicted_characters = "".join(raw_pred[0])
+ confidence_of_prediction = log_probs.cumprod(dim=0)[-1].item()
+
+ return predicted_characters, confidence_of_prediction
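For context, a self-contained sketch of the tensor shapes nn.CTCLoss expects, matching the criterion above; blank index 53 corresponds to the pad token, and the concrete sizes are made up for illustration:

    import torch
    from torch import nn

    T, B, C = 10, 2, 54                    # time steps, batch size, number of classes
    log_probs = torch.randn(T, B, C).log_softmax(2)
    targets = torch.randint(0, 53, (7,))   # flattened targets, no blanks
    input_lengths = torch.full((B,), T, dtype=torch.long)
    target_lengths = torch.tensor([4, 3])  # sums to len(targets)

    loss = nn.CTCLoss(blank=53)(log_probs, targets, input_lengths, target_lengths)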
diff --git a/src/text_recognizer/models/transformer_model.py b/src/text_recognizer/models/transformer_model.py
index a912122..12e497f 100644
--- a/src/text_recognizer/models/transformer_model.py
+++ b/src/text_recognizer/models/transformer_model.py
@@ -50,13 +50,15 @@ class TransformerModel(Model):
self.init_token = dataset_args["args"]["init_token"]
self.pad_token = dataset_args["args"]["pad_token"]
self.eos_token = dataset_args["args"]["eos_token"]
- self.max_len = 120
+ self.lower = dataset_args["args"]["lower"]
+ self.max_len = 100
if self._mapper is None:
self._mapper = EmnistMapper(
init_token=self.init_token,
pad_token=self.pad_token,
eos_token=self.eos_token,
+ lower=self.lower,
)
self.tensor_transform = ToTensor()
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index f958672..2b624bb 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -3,19 +3,18 @@ from .cnn_transformer import CNNTransformer
from .crnn import ConvolutionalRecurrentNetwork
from .ctc import greedy_decoder
from .densenet import DenseNet
-from .fcn import FCN
from .lenet import LeNet
-from .metrics import accuracy, accuracy_ignore_pad, cer, wer
+from .metrics import accuracy, cer, wer
from .mlp import MLP
from .residual_network import ResidualNetwork, ResidualNetworkEncoder
from .transformer import Transformer
from .unet import UNet
from .util import sliding_window
+from .vit import ViT
from .wide_resnet import WideResidualNetwork
__all__ = [
"accuracy",
- "accuracy_ignore_pad",
"cer",
"CNNTransformer",
"ConvolutionalRecurrentNetwork",
@@ -29,6 +28,7 @@ __all__ = [
"sliding_window",
"UNet",
"Transformer",
+ "ViT",
"wer",
"WideResidualNetwork",
]
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
index b2b74b3..caa73e3 100644
--- a/src/text_recognizer/networks/cnn_transformer.py
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -1,12 +1,13 @@
"""A CNN-Transformer for image to text recognition."""
from typing import Dict, Optional, Tuple
-from einops import rearrange
+from einops import rearrange, repeat
import torch
from torch import nn
from torch import Tensor
from text_recognizer.networks.transformer import PositionalEncoding, Transformer
+from text_recognizer.networks.util import activation_function
from text_recognizer.networks.util import configure_backbone
@@ -24,15 +25,21 @@ class CNNTransformer(nn.Module):
expansion_dim: int,
dropout_rate: float,
trg_pad_index: int,
+ max_len: int,
backbone: str,
backbone_args: Optional[Dict] = None,
activation: str = "gelu",
) -> None:
super().__init__()
self.trg_pad_index = trg_pad_index
+ self.vocab_size = vocab_size
self.backbone = configure_backbone(backbone, backbone_args)
- self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
- self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+ self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim)
+
+ self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
+ self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+
+ nn.init.normal_(self.character_embedding.weight, std=0.02)
self.adaptive_pool = (
nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None
@@ -48,7 +55,11 @@ class CNNTransformer(nn.Module):
activation,
)
- self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
+ self.head = nn.Sequential(
+ # nn.Linear(hidden_dim, hidden_dim * 2),
+ # activation_function(activation),
+ nn.Linear(hidden_dim, vocab_size),
+ )
def _create_trg_mask(self, trg: Tensor) -> Tensor:
# Move this outside the transformer.
@@ -96,7 +107,21 @@ class CNNTransformer(nn.Module):
else:
src = rearrange(src, "b c h w -> b (w h) c")
- src = self.position_encoding(src)
+ b, t, _ = src.shape
+
+ # Insert sos and eos token.
+ # sos_token = self.character_embedding(
+ # torch.Tensor([self.vocab_size - 2]).long().to(src.device)
+ # )
+ # eos_token = self.character_embedding(
+ # torch.Tensor([self.vocab_size - 1]).long().to(src.device)
+ # )
+
+ # sos_tokens = repeat(sos_token, "() h -> b h", b=b).unsqueeze(1)
+ # eos_tokens = repeat(eos_token, "() h -> b h", b=b).unsqueeze(1)
+ # src = torch.cat((sos_tokens, src, eos_tokens), dim=1)
+ # src = torch.cat((sos_tokens, src), dim=1)
+ src += self.src_position_embedding[:, :t]
return src
@@ -111,20 +136,22 @@ class CNNTransformer(nn.Module):
"""
trg = self.character_embedding(trg.long())
- trg = self.position_encoding(trg)
+ trg = self.trg_position_encoding(trg)
return trg
- def decode_image_features(self, h: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+ def decode_image_features(
+ self, image_features: Tensor, trg: Optional[Tensor] = None
+ ) -> Tensor:
"""Takes images features from the backbone and decodes them with the transformer."""
trg_mask = self._create_trg_mask(trg)
trg = self.target_embedding(trg)
- out = self.transformer(h, trg, trg_mask=trg_mask)
+ out = self.transformer(image_features, trg, trg_mask=trg_mask)
logits = self.head(out)
return logits
def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
"""Forward pass with CNN transfomer."""
- h = self.extract_image_features(x)
- logits = self.decode_image_features(h, trg)
+ image_features = self.extract_image_features(x)
+ logits = self.decode_image_features(image_features, trg)
return logits
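A small sketch of the learned source position embedding introduced above: a (1, max_len, hidden_dim) parameter sliced to the current sequence length and broadcast over the batch. The sizes here are assumptions:

    import torch

    max_len, hidden_dim = 128, 256
    src_position_embedding = torch.randn(1, max_len, hidden_dim)

    src = torch.randn(4, 97, hidden_dim)        # b t d, with t <= max_len
    src = src + src_position_embedding[:, :97]  # broadcasts over the batch dimension
    print(src.shape)                            # torch.Size([4, 97, 256])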
diff --git a/src/text_recognizer/networks/metrics.py b/src/text_recognizer/networks/metrics.py
index af9adb5..ffad792 100644
--- a/src/text_recognizer/networks/metrics.py
+++ b/src/text_recognizer/networks/metrics.py
@@ -6,28 +6,13 @@ from torch import Tensor
from text_recognizer.networks import greedy_decoder
-def accuracy_ignore_pad(
- output: Tensor,
- target: Tensor,
- pad_index: int = 79,
- eos_index: int = 81,
- seq_len: int = 97,
-) -> float:
- """Sets all predictions after eos to pad."""
- start_indices = torch.nonzero(target == eos_index, as_tuple=False).squeeze(1)
- end_indices = torch.arange(seq_len, target.shape[0] + 1, seq_len)
- for start, stop in zip(start_indices, end_indices):
- output[start + 1 : stop] = pad_index
-
- return accuracy(output, target)
-
-
-def accuracy(outputs: Tensor, labels: Tensor,) -> float:
+def accuracy(outputs: Tensor, labels: Tensor, pad_index: int = 53) -> float:
"""Computes the accuracy.
Args:
outputs (Tensor): The output from the network.
labels (Tensor): Ground truth labels.
+ pad_index (int): Padding index.
Returns:
float: The accuracy for the batch.
@@ -36,6 +21,12 @@ def accuracy(outputs: Tensor, labels: Tensor,) -> float:
_, predicted = torch.max(outputs, dim=-1)
+ # Mask out the pad tokens
+ mask = labels != pad_index
+
+ predicted *= mask
+ labels *= mask
+
acc = (predicted == labels).sum().float() / labels.shape[0]
acc = acc.item()
return acc
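A tiny sketch of the masking above: pad positions are zeroed in both tensors, so they compare equal and never penalize the prediction. Note this also means pad positions count as correct matches:

    import torch

    labels = torch.tensor([12, 7, 53, 53])   # 53 is the pad index
    predicted = torch.tensor([12, 9, 20, 53])
    mask = labels != 53

    print(predicted * mask)                  # tensor([12,  9,  0,  0])
    print(labels * mask)                     # tensor([12,  7,  0,  0])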
diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py
index e397224..c33f419 100644
--- a/src/text_recognizer/networks/residual_network.py
+++ b/src/text_recognizer/networks/residual_network.py
@@ -221,8 +221,8 @@ class ResidualNetworkEncoder(nn.Module):
nn.Conv2d(
in_channels=in_channels,
out_channels=self.block_sizes[0],
- kernel_size=3,
- stride=1,
+ kernel_size=7,
+ stride=2,
padding=1,
bias=False,
),
diff --git a/src/text_recognizer/networks/transformer/__init__.py b/src/text_recognizer/networks/transformer/__init__.py
index 020a917..9febc88 100644
--- a/src/text_recognizer/networks/transformer/__init__.py
+++ b/src/text_recognizer/networks/transformer/__init__.py
@@ -1,3 +1,3 @@
"""Transformer modules."""
from .positional_encoding import PositionalEncoding
-from .transformer import Decoder, Encoder, Transformer
+from .transformer import Decoder, Encoder, EncoderLayer, Transformer
diff --git a/src/text_recognizer/networks/transformer/transformer.py b/src/text_recognizer/networks/transformer/transformer.py
index c6e943e..dd180c4 100644
--- a/src/text_recognizer/networks/transformer/transformer.py
+++ b/src/text_recognizer/networks/transformer/transformer.py
@@ -6,11 +6,25 @@ import numpy as np
import torch
from torch import nn
from torch import Tensor
+import torch.nn.functional as F
from text_recognizer.networks.transformer.attention import MultiHeadAttention
from text_recognizer.networks.util import activation_function
+class GEGLU(nn.Module):
+ """GLU activation for improving feedforward activations."""
+
+ def __init__(self, dim_in: int, dim_out: int) -> None:
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward propagation."""
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList:
return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)])
@@ -36,9 +50,17 @@ class _ConvolutionalLayer(nn.Module):
activation: str = "relu",
) -> None:
super().__init__()
+
+ in_projection = (
+ nn.Sequential(
+ nn.Linear(hidden_dim, expansion_dim), activation_function(activation)
+ )
+ if activation != "glu"
+ else GEGLU(hidden_dim, expansion_dim)
+ )
+
self.layer = nn.Sequential(
- nn.Linear(in_features=hidden_dim, out_features=expansion_dim),
- activation_function(activation),
+ in_projection,
nn.Dropout(p=dropout_rate),
nn.Linear(in_features=expansion_dim, out_features=hidden_dim),
)
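A short usage sketch of the GEGLU module above: the projection doubles the width, and the second half gates the first through a GELU:

    import torch

    from text_recognizer.networks.transformer.transformer import GEGLU

    geglu = GEGLU(dim_in=64, dim_out=256)
    x = torch.randn(2, 10, 64)
    print(geglu(x).shape)  # torch.Size([2, 10, 256])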
diff --git a/src/text_recognizer/networks/util.py b/src/text_recognizer/networks/util.py
index e2d7955..711a952 100644
--- a/src/text_recognizer/networks/util.py
+++ b/src/text_recognizer/networks/util.py
@@ -39,6 +39,7 @@ def activation_function(activation: str) -> Type[nn.Module]:
[
["elu", nn.ELU(inplace=True)],
["gelu", nn.GELU()],
+ ["glu", nn.GLU()],
["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)],
["none", nn.Identity()],
["relu", nn.ReLU(inplace=True)],
diff --git a/src/text_recognizer/networks/vit.py b/src/text_recognizer/networks/vit.py
new file mode 100644
index 0000000..efb3701
--- /dev/null
+++ b/src/text_recognizer/networks/vit.py
@@ -0,0 +1,150 @@
+"""A Vision Transformer.
+
+Inspired by:
+https://openreview.net/pdf?id=YicbFdNTTy
+
+"""
+from typing import Optional, Tuple
+
+from einops import rearrange, repeat
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer import Transformer
+
+
+class ViT(nn.Module):
+ """Transfomer for image to sequence prediction."""
+
+ def __init__(
+ self,
+ num_encoder_layers: int,
+ num_decoder_layers: int,
+ hidden_dim: int,
+ vocab_size: int,
+ num_heads: int,
+ expansion_dim: int,
+ patch_dim: Tuple[int, int],
+ image_size: Tuple[int, int],
+ dropout_rate: float,
+ trg_pad_index: int,
+ max_len: int,
+ activation: str = "gelu",
+ ) -> None:
+ super().__init__()
+
+ self.trg_pad_index = trg_pad_index
+ self.patch_dim = patch_dim
+ self.num_patches = image_size[-1] // self.patch_dim[1]
+
+ # Encoder
+ self.patch_to_embedding = nn.Linear(
+ self.patch_dim[0] * self.patch_dim[1], hidden_dim
+ )
+ self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
+ self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
+ self.pos_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
+ self.dropout = nn.Dropout(dropout_rate)
+ self._init()
+
+ self.transformer = Transformer(
+ num_encoder_layers,
+ num_decoder_layers,
+ hidden_dim,
+ num_heads,
+ expansion_dim,
+ dropout_rate,
+ activation,
+ )
+
+ self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
+
+ def _init(self) -> None:
+ nn.init.normal_(self.character_embedding.weight, std=0.02)
+ # nn.init.normal_(self.pos_embedding.weight, std=0.02)
+
+ def _create_trg_mask(self, trg: Tensor) -> Tensor:
+ # Move this outside the transformer.
+ trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
+ trg_len = trg.shape[1]
+ trg_sub_mask = torch.tril(
+ torch.ones((trg_len, trg_len), device=trg.device)
+ ).bool()
+ trg_mask = trg_pad_mask & trg_sub_mask
+ return trg_mask
+
+ def encoder(self, src: Tensor) -> Tensor:
+ """Forward pass with the encoder of the transformer."""
+ return self.transformer.encoder(src)
+
+ def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
+ """Forward pass with the decoder of the transformer + classification head."""
+ return self.head(
+ self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+ )
+
+ def extract_image_features(self, src: Tensor) -> Tensor:
+ """Extracts image features with a backbone neural network.
+
+ It seem like the winning idea was to swap channels and width dimension and collapse
+ the height dimension. The transformer is learning like a baby with this implementation!!! :D
+ Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D
+
+ Args:
+ src (Tensor): Input tensor.
+
+ Returns:
+ Tensor: An input sequence for the transformer encoder.
+
+ """
+ # If batch dimension is missing, it needs to be added.
+ if len(src.shape) < 4:
+ src = src[(None,) * (4 - len(src.shape))]
+
+ patches = rearrange(
+ src,
+ "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
+ p1=self.patch_dim[0],
+ p2=self.patch_dim[1],
+ )
+
+ # From patches to encoded sequence.
+ x = self.patch_to_embedding(patches)
+ b, n, _ = x.shape
+ cls_tokens = repeat(self.cls_token, "() n d -> b n d", b=b)
+ x = torch.cat((cls_tokens, x), dim=1)
+ x += self.pos_embedding[:, : (n + 1)]
+ x = self.dropout(x)
+
+ return x
+
+ def target_embedding(self, trg: Tensor) -> Tensor:
+ """Encodes the target tensor with character and position embeddings.
+
+ Args:
+ trg (Tensor): Target tensor.
+
+ Returns:
+ Tensor: Encoded target tensor.
+
+ """
+ _, n = trg.shape
+ trg = self.character_embedding(trg.long())
+ trg += self.pos_embedding[:, :n]
+ return trg
+
+ def decode_image_features(self, h: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+ """Takes images features from the backbone and decodes them with the transformer."""
+ trg_mask = self._create_trg_mask(trg)
+ trg = self.target_embedding(trg)
+ out = self.transformer(h, trg, trg_mask=trg_mask)
+
+ logits = self.head(out)
+ return logits
+
+ def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+ """Forward pass with CNN transfomer."""
+ h = self.extract_image_features(x)
+ logits = self.decode_image_features(h, trg)
+ return logits
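For reference, a minimal sketch of the patch rearrangement in extract_image_features, with assumed sizes (a 28x952 line image cut into 28x28 patches):

    import torch
    from einops import rearrange

    src = torch.randn(1, 1, 28, 952)  # b c h w
    patches = rearrange(
        src, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=28, p2=28
    )
    print(patches.shape)              # torch.Size([1, 34, 784]), i.e. 34 patches of 784 pixels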
diff --git a/src/text_recognizer/networks/wide_resnet.py b/src/text_recognizer/networks/wide_resnet.py
index 28f3380..b767778 100644
--- a/src/text_recognizer/networks/wide_resnet.py
+++ b/src/text_recognizer/networks/wide_resnet.py
@@ -113,6 +113,7 @@ class WideResidualNetwork(nn.Module):
dropout_rate: float = 0.0,
num_layers: int = 3,
block: Type[nn.Module] = WideBlock,
+ num_stages: Optional[List[int]] = None,
activation: str = "relu",
use_decoder: bool = True,
) -> None:
@@ -127,6 +128,7 @@ class WideResidualNetwork(nn.Module):
dropout_rate (float): The dropout rate. Defaults to 0.0.
num_layers (int): Number of layers of blocks. Defaults to 3.
block (Type[nn.Module]): The default block is WideBlock. Defaults to WideBlock.
+ num_stages (Optional[List[int]]): If given, these channel values are used for the stages instead of the computed defaults. Defaults to None.
activation (str): Name of the activation to use. Defaults to "relu".
use_decoder (bool): If True, the network outputs character predictions; if False, it outputs a
latent vector. Defaults to True.
@@ -149,9 +151,14 @@ class WideResidualNetwork(nn.Module):
self.dropout_rate = dropout_rate
self.activation = activation_function(activation)
- self.num_stages = [self.in_planes] + [
- self.in_planes * 2 ** n * self.width_factor for n in range(self.num_layers)
- ]
+ if num_stages is None:
+ self.num_stages = [self.in_planes] + [
+ self.in_planes * 2 ** n * self.width_factor
+ for n in range(self.num_layers)
+ ]
+ else:
+ self.num_stages = [self.in_planes] + num_stages
+
self.num_stages = list(zip(self.num_stages, self.num_stages[1:]))
self.strides = [1] + [2] * (self.num_layers - 1)
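A quick sketch of the default stage computation above, using the classic WideResNet values in_planes=16, width_factor=4, num_layers=3 (assumed here for illustration); passing num_stages=[64, 128, 256] would reproduce the same widths explicitly:

    in_planes, width_factor, num_layers = 16, 4, 3

    num_stages = [in_planes] + [
        in_planes * 2 ** n * width_factor for n in range(num_layers)
    ]
    print(num_stages)                             # [16, 64, 128, 256]
    print(list(zip(num_stages, num_stages[1:])))  # [(16, 64), (64, 128), (128, 256)]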