summaryrefslogtreecommitdiff
path: root/text_recognizer
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer')
-rw-r--r--text_recognizer/__init__.py1
-rw-r--r--text_recognizer/character_predictor.py29
-rw-r--r--text_recognizer/datasets/__init__.py39
-rw-r--r--text_recognizer/datasets/dataset.py152
-rw-r--r--text_recognizer/datasets/emnist_dataset.py131
-rw-r--r--text_recognizer/datasets/emnist_essentials.json1
-rw-r--r--text_recognizer/datasets/emnist_lines_dataset.py359
-rw-r--r--text_recognizer/datasets/iam_dataset.py133
-rw-r--r--text_recognizer/datasets/iam_lines_dataset.py110
-rw-r--r--text_recognizer/datasets/iam_paragraphs_dataset.py291
-rw-r--r--text_recognizer/datasets/iam_preprocessor.py196
-rw-r--r--text_recognizer/datasets/sentence_generator.py81
-rw-r--r--text_recognizer/datasets/transforms.py266
-rw-r--r--text_recognizer/datasets/util.py209
-rw-r--r--text_recognizer/line_predictor.py28
-rw-r--r--text_recognizer/models/__init__.py18
-rw-r--r--text_recognizer/models/base.py455
-rw-r--r--text_recognizer/models/character_model.py88
-rw-r--r--text_recognizer/models/crnn_model.py119
-rw-r--r--text_recognizer/models/ctc_transformer_model.py120
-rw-r--r--text_recognizer/models/segmentation_model.py75
-rw-r--r--text_recognizer/models/transformer_model.py124
-rw-r--r--text_recognizer/models/vqvae_model.py80
-rw-r--r--text_recognizer/networks/__init__.py43
-rw-r--r--text_recognizer/networks/beam.py83
-rw-r--r--text_recognizer/networks/cnn.py101
-rw-r--r--text_recognizer/networks/cnn_transformer.py158
-rw-r--r--text_recognizer/networks/crnn.py110
-rw-r--r--text_recognizer/networks/ctc.py58
-rw-r--r--text_recognizer/networks/densenet.py225
-rw-r--r--text_recognizer/networks/lenet.py68
-rw-r--r--text_recognizer/networks/loss/__init__.py2
-rw-r--r--text_recognizer/networks/loss/loss.py69
-rw-r--r--text_recognizer/networks/metrics.py123
-rw-r--r--text_recognizer/networks/mlp.py73
-rw-r--r--text_recognizer/networks/residual_network.py310
-rw-r--r--text_recognizer/networks/stn.py44
-rw-r--r--text_recognizer/networks/transducer/__init__.py3
-rw-r--r--text_recognizer/networks/transducer/tds_conv.py208
-rw-r--r--text_recognizer/networks/transducer/test.py60
-rw-r--r--text_recognizer/networks/transducer/transducer.py410
-rw-r--r--text_recognizer/networks/transformer/__init__.py3
-rw-r--r--text_recognizer/networks/transformer/attention.py93
-rw-r--r--text_recognizer/networks/transformer/positional_encoding.py32
-rw-r--r--text_recognizer/networks/transformer/transformer.py264
-rw-r--r--text_recognizer/networks/unet.py255
-rw-r--r--text_recognizer/networks/util.py89
-rw-r--r--text_recognizer/networks/vit.py150
-rw-r--r--text_recognizer/networks/vq_transformer.py150
-rw-r--r--text_recognizer/networks/vqvae/__init__.py5
-rw-r--r--text_recognizer/networks/vqvae/decoder.py133
-rw-r--r--text_recognizer/networks/vqvae/encoder.py147
-rw-r--r--text_recognizer/networks/vqvae/vector_quantizer.py119
-rw-r--r--text_recognizer/networks/vqvae/vqvae.py74
-rw-r--r--text_recognizer/networks/wide_resnet.py221
-rw-r--r--text_recognizer/paragraph_text_recognizer.py153
-rw-r--r--text_recognizer/tests/__init__.py1
-rw-r--r--text_recognizer/tests/support/__init__.py2
-rw-r--r--text_recognizer/tests/support/create_emnist_lines_support_files.py51
-rw-r--r--text_recognizer/tests/support/create_emnist_support_files.py30
-rw-r--r--text_recognizer/tests/support/create_iam_lines_support_files.py50
-rw-r--r--text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.pngbin0 -> 2301 bytes
-rw-r--r--text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.pngbin0 -> 5424 bytes
-rw-r--r--text_recognizer/tests/support/emnist_lines/they<eos>.pngbin0 -> 1391 bytes
-rw-r--r--text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.pngbin0 -> 5170 bytes
-rw-r--r--text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.pngbin0 -> 3617 bytes
-rw-r--r--text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.pngbin0 -> 3923 bytes
-rw-r--r--text_recognizer/tests/support/iam_paragraphs/a01-000u.jpgbin0 -> 14890 bytes
-rw-r--r--text_recognizer/tests/test_character_predictor.py31
-rw-r--r--text_recognizer/tests/test_line_predictor.py35
-rw-r--r--text_recognizer/tests/test_paragraph_text_recognizer.py37
-rw-r--r--text_recognizer/util.py52
-rw-r--r--text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt3
-rw-r--r--text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt3
-rw-r--r--text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt3
-rw-r--r--text_recognizer/weights/SegmentationModel_IamParagraphsDataset_FCN_weights.ptbin0 -> 8588813 bytes
-rw-r--r--text_recognizer/weights/SegmentationModel_IamParagraphsDataset_UNet_weights.ptbin0 -> 92335101 bytes
-rw-r--r--text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.ptbin0 -> 21687018 bytes
78 files changed, 7439 insertions, 0 deletions
diff --git a/text_recognizer/__init__.py b/text_recognizer/__init__.py
new file mode 100644
index 0000000..3dc1f76
--- /dev/null
+++ b/text_recognizer/__init__.py
@@ -0,0 +1 @@
+__version__ = "0.1.0"
diff --git a/text_recognizer/character_predictor.py b/text_recognizer/character_predictor.py
new file mode 100644
index 0000000..ad71289
--- /dev/null
+++ b/text_recognizer/character_predictor.py
@@ -0,0 +1,29 @@
+"""CharacterPredictor class."""
+from typing import Dict, Tuple, Type, Union
+
+import numpy as np
+from torch import nn
+
+from text_recognizer import datasets, networks
+from text_recognizer.models import CharacterModel
+from text_recognizer.util import read_image
+
+
+class CharacterPredictor:
+ """Recognizes the character in handwritten character images."""
+
+ def __init__(self, network_fn: str, dataset: str) -> None:
+ """Intializes the CharacterModel and load the pretrained weights."""
+ network_fn = getattr(networks, network_fn)
+ dataset = getattr(datasets, dataset)
+ self.model = CharacterModel(network_fn=network_fn, dataset=dataset)
+ self.model.eval()
+ self.model.use_swa_model()
+
+ def predict(self, image_or_filename: Union[np.ndarray, str]) -> Tuple[str, float]:
+ """Predict on a single images contianing a handwritten character."""
+ if isinstance(image_or_filename, str):
+ image = read_image(image_or_filename, grayscale=True)
+ else:
+ image = image_or_filename
+ return self.model.predict_on_image(image)
diff --git a/text_recognizer/datasets/__init__.py b/text_recognizer/datasets/__init__.py
new file mode 100644
index 0000000..a6c1c59
--- /dev/null
+++ b/text_recognizer/datasets/__init__.py
@@ -0,0 +1,39 @@
+"""Dataset modules."""
+from .emnist_dataset import EmnistDataset
+from .emnist_lines_dataset import (
+ construct_image_from_string,
+ EmnistLinesDataset,
+ get_samples_by_character,
+)
+from .iam_dataset import IamDataset
+from .iam_lines_dataset import IamLinesDataset
+from .iam_paragraphs_dataset import IamParagraphsDataset
+from .iam_preprocessor import load_metadata, Preprocessor
+from .transforms import AddTokens, Transpose
+from .util import (
+ _download_raw_dataset,
+ compute_sha256,
+ DATA_DIRNAME,
+ download_url,
+ EmnistMapper,
+ ESSENTIALS_FILENAME,
+)
+
+__all__ = [
+ "_download_raw_dataset",
+ "AddTokens",
+ "compute_sha256",
+ "construct_image_from_string",
+ "DATA_DIRNAME",
+ "download_url",
+ "EmnistDataset",
+ "EmnistMapper",
+ "EmnistLinesDataset",
+ "get_samples_by_character",
+ "load_metadata",
+ "IamDataset",
+ "IamLinesDataset",
+ "IamParagraphsDataset",
+ "Preprocessor",
+ "Transpose",
+]
diff --git a/text_recognizer/datasets/dataset.py b/text_recognizer/datasets/dataset.py
new file mode 100644
index 0000000..e794605
--- /dev/null
+++ b/text_recognizer/datasets/dataset.py
@@ -0,0 +1,152 @@
+"""Abstract dataset class."""
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import Tensor
+from torch.utils import data
+from torchvision.transforms import ToTensor
+
+import text_recognizer.datasets.transforms as transforms
+from text_recognizer.datasets.util import EmnistMapper
+
+
+class Dataset(data.Dataset):
+ """Abstract class for with common methods for all datasets."""
+
+ def __init__(
+ self,
+ train: bool,
+ subsample_fraction: float = None,
+ transform: Optional[List[Dict]] = None,
+ target_transform: Optional[List[Dict]] = None,
+ init_token: Optional[str] = None,
+ pad_token: Optional[str] = None,
+ eos_token: Optional[str] = None,
+ lower: bool = False,
+ ) -> None:
+ """Initialization of Dataset class.
+
+ Args:
+ train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False.
+ subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None.
+ transform (Optional[List[Dict]]): List of Transform types and args for input data. Defaults to None.
+ target_transform (Optional[List[Dict]]): List of Transform types and args for output data. Defaults to None.
+ init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
+ pad_token (Optional[str]): String representing the pad token. Defaults to None.
+ eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
+ lower (bool): Only use lower case letters. Defaults to False.
+
+ Raises:
+ ValueError: If subsample_fraction is not None and outside the range (0, 1).
+
+ """
+ self.train = train
+ self.split = "train" if self.train else "test"
+
+ if subsample_fraction is not None:
+ if not 0.0 < subsample_fraction < 1.0:
+ raise ValueError("The subsample fraction must be in (0, 1).")
+ self.subsample_fraction = subsample_fraction
+
+ self._mapper = EmnistMapper(
+ init_token=init_token, eos_token=eos_token, pad_token=pad_token, lower=lower
+ )
+ self._input_shape = self._mapper.input_shape
+ self._output_shape = self._mapper._num_classes
+ self.num_classes = self.mapper.num_classes
+
+ # Set transforms.
+ self.transform = self._configure_transform(transform)
+ self.target_transform = self._configure_target_transform(target_transform)
+
+ self._data = None
+ self._targets = None
+
+ def _configure_transform(self, transform: List[Dict]) -> transforms.Compose:
+ transform_list = []
+ if transform is not None:
+ for t in transform:
+ t_type = t["type"]
+ t_args = t["args"] or {}
+ transform_list.append(getattr(transforms, t_type)(**t_args))
+ else:
+ transform_list.append(ToTensor())
+ return transforms.Compose(transform_list)
+
+ def _configure_target_transform(
+ self, target_transform: List[Dict]
+ ) -> transforms.Compose:
+ target_transform_list = [torch.tensor]
+ if target_transform is not None:
+ for t in target_transform:
+ t_type = t["type"]
+ t_args = t["args"] or {}
+ target_transform_list.append(getattr(transforms, t_type)(**t_args))
+ return transforms.Compose(target_transform_list)
+
+ @property
+ def data(self) -> Tensor:
+ """The input data."""
+ return self._data
+
+ @property
+ def targets(self) -> Tensor:
+ """The target data."""
+ return self._targets
+
+ @property
+ def input_shape(self) -> Tuple:
+ """Input shape of the data."""
+ return self._input_shape
+
+ @property
+ def output_shape(self) -> Tuple:
+ """Output shape of the data."""
+ return self._output_shape
+
+ @property
+ def mapper(self) -> EmnistMapper:
+ """Returns the EmnistMapper."""
+ return self._mapper
+
+ @property
+ def mapping(self) -> Dict:
+ """Return EMNIST mapping from index to character."""
+ return self._mapper.mapping
+
+ @property
+ def inverse_mapping(self) -> Dict:
+ """Returns the inverse mapping from character to index."""
+ return self.mapper.inverse_mapping
+
+ def _subsample(self) -> None:
+ """Only this fraction of the data will be loaded."""
+ if self.subsample_fraction is None:
+ return
+ num_subsample = int(self.data.shape[0] * self.subsample_fraction)
+ self._data = self.data[:num_subsample]
+ self._targets = self.targets[:num_subsample]
+
+ def __len__(self) -> int:
+ """Returns the length of the dataset."""
+ return len(self.data)
+
+ def load_or_generate_data(self) -> None:
+ """Load or generate dataset data."""
+ raise NotImplementedError
+
+ def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
+ """Fetches samples from the dataset.
+
+ Args:
+ index (Union[int, torch.Tensor]): The indices of the samples to fetch.
+
+ Raises:
+ NotImplementedError: If the method is not implemented in child class.
+
+ """
+ raise NotImplementedError
+
+ def __repr__(self) -> str:
+ """Returns information about the dataset."""
+ raise NotImplementedError
diff --git a/text_recognizer/datasets/emnist_dataset.py b/text_recognizer/datasets/emnist_dataset.py
new file mode 100644
index 0000000..9884fdf
--- /dev/null
+++ b/text_recognizer/datasets/emnist_dataset.py
@@ -0,0 +1,131 @@
+"""Emnist dataset: black and white images of handwritten characters (Aa-Zz) and digits (0-9)."""
+
+import json
+from pathlib import Path
+from typing import Callable, Optional, Tuple, Union
+
+from loguru import logger
+import numpy as np
+from PIL import Image
+import torch
+from torch import Tensor
+from torchvision.datasets import EMNIST
+from torchvision.transforms import Compose, ToTensor
+
+from text_recognizer.datasets.dataset import Dataset
+from text_recognizer.datasets.transforms import Transpose
+from text_recognizer.datasets.util import DATA_DIRNAME
+
+
+class EmnistDataset(Dataset):
+ """This is a class for resampling and subsampling the PyTorch EMNIST dataset."""
+
+ def __init__(
+ self,
+ pad_token: str = None,
+ train: bool = False,
+ sample_to_balance: bool = False,
+ subsample_fraction: float = None,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ seed: int = 4711,
+ ) -> None:
+ """Loads the dataset and the mappings.
+
+ Args:
+ pad_token (str): The pad token symbol. Defaults to _.
+ train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False.
+ sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False.
+ subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None.
+ transform (Optional[Callable]): Transform(s) for input data. Defaults to None.
+ target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None.
+ seed (int): Seed number. Defaults to 4711.
+
+ """
+ super().__init__(
+ train=train,
+ subsample_fraction=subsample_fraction,
+ transform=transform,
+ target_transform=target_transform,
+ pad_token=pad_token,
+ )
+
+ self.sample_to_balance = sample_to_balance
+
+ # Have to transpose the emnist characters, ToTensor norms input between [0,1].
+ if transform is None:
+ self.transform = Compose([Transpose(), ToTensor()])
+
+ self.target_transform = None
+
+ self.seed = seed
+
+ def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
+ """Fetches samples from the dataset.
+
+ Args:
+ index (Union[int, Tensor]): The indices of the samples to fetch.
+
+ Returns:
+ Tuple[Tensor, Tensor]: Data target tuple.
+
+ """
+ if torch.is_tensor(index):
+ index = index.tolist()
+
+ data = self.data[index]
+ targets = self.targets[index]
+
+ if self.transform:
+ data = self.transform(data)
+
+ if self.target_transform:
+ targets = self.target_transform(targets)
+
+ return data, targets
+
+ def __repr__(self) -> str:
+ """Returns information about the dataset."""
+ return (
+ "EMNIST Dataset\n"
+ f"Num classes: {self.num_classes}\n"
+ f"Input shape: {self.input_shape}\n"
+ f"Mapping: {self.mapper.mapping}\n"
+ )
+
+ def _sample_to_balance(self) -> None:
+ """Because the dataset is not balanced, we take at most the mean number of instances per class."""
+ np.random.seed(self.seed)
+ x = self._data
+ y = self._targets
+ num_to_sample = int(np.bincount(y.flatten()).mean())
+ all_sampled_indices = []
+ for label in np.unique(y.flatten()):
+ inds = np.where(y == label)[0]
+ sampled_indices = np.unique(np.random.choice(inds, num_to_sample))
+ all_sampled_indices.append(sampled_indices)
+ indices = np.concatenate(all_sampled_indices)
+ x_sampled = x[indices]
+ y_sampled = y[indices]
+ self._data = x_sampled
+ self._targets = y_sampled
+
+ def load_or_generate_data(self) -> None:
+ """Fetch the EMNIST dataset."""
+ dataset = EMNIST(
+ root=DATA_DIRNAME,
+ split="byclass",
+ train=self.train,
+ download=False,
+ transform=None,
+ target_transform=None,
+ )
+
+ self._data = dataset.data
+ self._targets = dataset.targets
+
+ if self.sample_to_balance:
+ self._sample_to_balance()
+
+ if self.subsample_fraction is not None:
+ self._subsample()
diff --git a/text_recognizer/datasets/emnist_essentials.json b/text_recognizer/datasets/emnist_essentials.json
new file mode 100644
index 0000000..2a0648a
--- /dev/null
+++ b/text_recognizer/datasets/emnist_essentials.json
@@ -0,0 +1 @@
+{"mapping": [[0, "0"], [1, "1"], [2, "2"], [3, "3"], [4, "4"], [5, "5"], [6, "6"], [7, "7"], [8, "8"], [9, "9"], [10, "A"], [11, "B"], [12, "C"], [13, "D"], [14, "E"], [15, "F"], [16, "G"], [17, "H"], [18, "I"], [19, "J"], [20, "K"], [21, "L"], [22, "M"], [23, "N"], [24, "O"], [25, "P"], [26, "Q"], [27, "R"], [28, "S"], [29, "T"], [30, "U"], [31, "V"], [32, "W"], [33, "X"], [34, "Y"], [35, "Z"], [36, "a"], [37, "b"], [38, "c"], [39, "d"], [40, "e"], [41, "f"], [42, "g"], [43, "h"], [44, "i"], [45, "j"], [46, "k"], [47, "l"], [48, "m"], [49, "n"], [50, "o"], [51, "p"], [52, "q"], [53, "r"], [54, "s"], [55, "t"], [56, "u"], [57, "v"], [58, "w"], [59, "x"], [60, "y"], [61, "z"]], "input_shape": [28, 28]}
diff --git a/text_recognizer/datasets/emnist_lines_dataset.py b/text_recognizer/datasets/emnist_lines_dataset.py
new file mode 100644
index 0000000..1992446
--- /dev/null
+++ b/text_recognizer/datasets/emnist_lines_dataset.py
@@ -0,0 +1,359 @@
+"""Emnist Lines dataset: synthetic handwritten lines dataset made from Emnist characters."""
+
+from collections import defaultdict
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import click
+import h5py
+from loguru import logger
+import numpy as np
+import torch
+from torch import Tensor
+import torch.nn.functional as F
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets.dataset import Dataset
+from text_recognizer.datasets.emnist_dataset import EmnistDataset, Transpose
+from text_recognizer.datasets.sentence_generator import SentenceGenerator
+from text_recognizer.datasets.util import (
+ DATA_DIRNAME,
+ EmnistMapper,
+ ESSENTIALS_FILENAME,
+)
+
+DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines"
+
+MAX_WIDTH = 952
+
+
+class EmnistLinesDataset(Dataset):
+ """Synthetic dataset of lines from the Brown corpus with Emnist characters."""
+
+ def __init__(
+ self,
+ train: bool = False,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ subsample_fraction: float = None,
+ max_length: int = 34,
+ min_overlap: float = 0,
+ max_overlap: float = 0.33,
+ num_samples: int = 10000,
+ seed: int = 4711,
+ init_token: Optional[str] = None,
+ pad_token: Optional[str] = None,
+ eos_token: Optional[str] = None,
+ lower: bool = False,
+ ) -> None:
+ """Set attributes and loads the dataset.
+
+ Args:
+ train (bool): Flag for the filename. Defaults to False. Defaults to None.
+ transform (Optional[Callable]): The transform of the data. Defaults to None.
+ target_transform (Optional[Callable]): The transform of the target. Defaults to None.
+ subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None.
+ max_length (int): The maximum number of characters. Defaults to 34.
+ min_overlap (float): The minimum overlap between concatenated images. Defaults to 0.
+ max_overlap (float): The maximum overlap between concatenated images. Defaults to 0.33.
+ num_samples (int): Number of samples to generate. Defaults to 10000.
+ seed (int): Seed number. Defaults to 4711.
+ init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
+ pad_token (Optional[str]): String representing the pad token. Defaults to None.
+ eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
+ lower (bool): If True, convert uppercase letters to lowercase. Otherwise, use both upper and lowercase.
+
+ """
+ self.pad_token = "_" if pad_token is None else pad_token
+
+ super().__init__(
+ train=train,
+ transform=transform,
+ target_transform=target_transform,
+ subsample_fraction=subsample_fraction,
+ init_token=init_token,
+ pad_token=self.pad_token,
+ eos_token=eos_token,
+ lower=lower,
+ )
+
+ # Extract dataset information.
+ self._input_shape = self._mapper.input_shape
+ self.num_classes = self._mapper.num_classes
+
+ self.max_length = max_length
+ self.min_overlap = min_overlap
+ self.max_overlap = max_overlap
+ self.num_samples = num_samples
+ self._input_shape = (
+ self.input_shape[0],
+ self.input_shape[1] * self.max_length,
+ )
+ self._output_shape = (self.max_length, self.num_classes)
+ self.seed = seed
+
+ # Placeholders for the dataset.
+ self._data = None
+ self._target = None
+
+ def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
+ """Fetches data, target pair of the dataset for a given and index or indices.
+
+ Args:
+ index (Union[int, Tensor]): Either a list or int of indices/index.
+
+ Returns:
+ Tuple[Tensor, Tensor]: Data target pair.
+
+ """
+ if torch.is_tensor(index):
+ index = index.tolist()
+
+ data = self.data[index]
+ targets = self.targets[index]
+
+ if self.transform:
+ data = self.transform(data)
+
+ if self.target_transform:
+ targets = self.target_transform(targets)
+
+ return data, targets
+
+ def __repr__(self) -> str:
+ """Returns information about the dataset."""
+ return (
+ "EMNIST Lines Dataset\n" # pylint: disable=no-member
+ f"Max length: {self.max_length}\n"
+ f"Min overlap: {self.min_overlap}\n"
+ f"Max overlap: {self.max_overlap}\n"
+ f"Num classes: {self.num_classes}\n"
+ f"Input shape: {self.input_shape}\n"
+ f"Data: {self.data.shape}\n"
+ f"Tagets: {self.targets.shape}\n"
+ )
+
+ @property
+ def data_filename(self) -> Path:
+ """Path to the h5 file."""
+ filename = "train.pt" if self.train else "test.pt"
+ return DATA_DIRNAME / filename
+
+ def load_or_generate_data(self) -> None:
+ """Loads the dataset, if it does not exist a new dataset is generated before loading it."""
+ np.random.seed(self.seed)
+
+ if not self.data_filename.exists():
+ self._generate_data()
+ self._load_data()
+ self._subsample()
+
+ def _load_data(self) -> None:
+ """Loads the dataset from the h5 file."""
+ logger.debug("EmnistLinesDataset loading data from HDF5...")
+ with h5py.File(self.data_filename, "r") as f:
+ self._data = f["data"][()]
+ self._targets = f["targets"][()]
+
+ def _generate_data(self) -> str:
+ """Generates a dataset with the Brown corpus and Emnist characters."""
+ logger.debug("Generating data...")
+
+ sentence_generator = SentenceGenerator(self.max_length)
+
+ # Load emnist dataset.
+ emnist = EmnistDataset(
+ train=self.train, sample_to_balance=True, pad_token=self.pad_token
+ )
+ emnist.load_or_generate_data()
+
+ samples_by_character = get_samples_by_character(
+ emnist.data.numpy(), emnist.targets.numpy(), self.mapper.mapping,
+ )
+
+ DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+ with h5py.File(self.data_filename, "a") as f:
+ data, targets = create_dataset_of_images(
+ self.num_samples,
+ samples_by_character,
+ sentence_generator,
+ self.min_overlap,
+ self.max_overlap,
+ )
+
+ targets = convert_strings_to_categorical_labels(
+ targets, emnist.inverse_mapping
+ )
+
+ f.create_dataset("data", data=data, dtype="u1", compression="lzf")
+ f.create_dataset("targets", data=targets, dtype="u1", compression="lzf")
+
+
+def get_samples_by_character(
+ samples: np.ndarray, labels: np.ndarray, mapping: Dict
+) -> defaultdict:
+ """Creates a dictionary with character as key and value as the list of images of that character.
+
+ Args:
+ samples (np.ndarray): Dataset of images of characters.
+ labels (np.ndarray): The labels for each image.
+ mapping (Dict): The Emnist mapping dictionary.
+
+ Returns:
+ defaultdict: A dictionary with characters as keys and list of images as values.
+
+ """
+ samples_by_character = defaultdict(list)
+ for sample, label in zip(samples, labels.flatten()):
+ samples_by_character[mapping[label]].append(sample)
+ return samples_by_character
+
+
+def select_letter_samples_for_string(
+ string: str, samples_by_character: Dict
+) -> List[np.ndarray]:
+ """Randomly selects Emnist characters to use for the senetence.
+
+ Args:
+ string (str): The word or sentence.
+ samples_by_character (Dict): The dictionary of emnist images of each character.
+
+ Returns:
+ List[np.ndarray]: A list of emnist images of the string.
+
+ """
+ zero_image = np.zeros((28, 28), np.uint8)
+ sample_image_by_character = {}
+ for character in string:
+ if character in sample_image_by_character:
+ continue
+ samples = samples_by_character[character]
+ sample = samples[np.random.choice(len(samples))] if samples else zero_image
+ sample_image_by_character[character] = sample.reshape(28, 28).swapaxes(0, 1)
+ return [sample_image_by_character[character] for character in string]
+
+
+def construct_image_from_string(
+ string: str, samples_by_character: Dict, min_overlap: float, max_overlap: float
+) -> np.ndarray:
+ """Concatenates images of the characters in the string.
+
+ The concatination is made with randomly selected overlap so that some portion of the character will overlap.
+
+ Args:
+ string (str): The word or sentence.
+ samples_by_character (Dict): The dictionary of emnist images of each character.
+ min_overlap (float): Minimum amount of overlap between Emnist images.
+ max_overlap (float): Maximum amount of overlap between Emnist images.
+
+ Returns:
+ np.ndarray: The Emnist image of the string.
+
+ """
+ overlap = np.random.uniform(min_overlap, max_overlap)
+ sampled_images = select_letter_samples_for_string(string, samples_by_character)
+ length = len(sampled_images)
+ height, width = sampled_images[0].shape
+ next_overlap_width = width - int(overlap * width)
+ concatenated_image = np.zeros((height, width * length), np.uint8)
+ x = 0
+ for image in sampled_images:
+ concatenated_image[:, x : (x + width)] += image
+ x += next_overlap_width
+
+ if concatenated_image.shape[-1] > MAX_WIDTH:
+ concatenated_image = Tensor(concatenated_image).unsqueeze(0)
+ concatenated_image = F.interpolate(
+ concatenated_image, size=MAX_WIDTH, mode="nearest"
+ )
+ concatenated_image = concatenated_image.squeeze(0).numpy()
+
+ return np.minimum(255, concatenated_image)
+
+
+def create_dataset_of_images(
+ length: int,
+ samples_by_character: Dict,
+ sentence_generator: SentenceGenerator,
+ min_overlap: float,
+ max_overlap: float,
+) -> Tuple[np.ndarray, List[str]]:
+ """Creates a dataset with images and labels from strings generated from the SentenceGenerator.
+
+ Args:
+ length (int): The number of characters for each string.
+ samples_by_character (Dict): The dictionary of emnist images of each character.
+ sentence_generator (SentenceGenerator): A SentenceGenerator objest.
+ min_overlap (float): Minimum amount of overlap between Emnist images.
+ max_overlap (float): Maximum amount of overlap between Emnist images.
+
+ Returns:
+ Tuple[np.ndarray, List[str]]: A list of Emnist images and a list of the strings (labels).
+
+ Raises:
+ RuntimeError: If the sentence generator is not able to generate a string.
+
+ """
+ sample_label = sentence_generator.generate()
+ sample_image = construct_image_from_string(sample_label, samples_by_character, 0, 0)
+ images = np.zeros((length, sample_image.shape[0], sample_image.shape[1]), np.uint8)
+ labels = []
+ for n in range(length):
+ label = None
+ # Try several times to generate before actually throwing an error.
+ for _ in range(10):
+ try:
+ label = sentence_generator.generate()
+ break
+ except Exception: # pylint: disable=broad-except
+ pass
+ if label is None:
+ raise RuntimeError("Was not able to generate a valid string.")
+ images[n] = construct_image_from_string(
+ label, samples_by_character, min_overlap, max_overlap
+ )
+ labels.append(label)
+ return images, labels
+
+
+def convert_strings_to_categorical_labels(
+ labels: List[str], mapping: Dict
+) -> np.ndarray:
+ """Translates a string of characters in to a target array of class int."""
+ return np.array([[mapping[c] for c in label] for label in labels])
+
+
+@click.command()
+@click.option(
+ "--max_length", type=int, default=34, help="Number of characters in a sentence."
+)
+@click.option(
+ "--min_overlap", type=float, default=0.0, help="Min overlap between characters."
+)
+@click.option(
+ "--max_overlap", type=float, default=0.33, help="Max overlap between characters."
+)
+@click.option("--num_train", type=int, default=10_000, help="Number of train examples.")
+@click.option("--num_test", type=int, default=1_000, help="Number of test examples.")
+def create_datasets(
+ max_length: int = 34,
+ min_overlap: float = 0,
+ max_overlap: float = 0.33,
+ num_train: int = 10000,
+ num_test: int = 1000,
+) -> None:
+ """Creates a training an validation dataset of Emnist lines."""
+ num_samples = [num_train, num_test]
+ for num, train in zip(num_samples, [True, False]):
+ emnist_lines = EmnistLinesDataset(
+ train=train,
+ max_length=max_length,
+ min_overlap=min_overlap,
+ max_overlap=max_overlap,
+ num_samples=num,
+ )
+ emnist_lines.load_or_generate_data()
+
+
+if __name__ == "__main__":
+ create_datasets()
diff --git a/text_recognizer/datasets/iam_dataset.py b/text_recognizer/datasets/iam_dataset.py
new file mode 100644
index 0000000..a8998b9
--- /dev/null
+++ b/text_recognizer/datasets/iam_dataset.py
@@ -0,0 +1,133 @@
+"""Class for loading the IAM dataset, which encompasses both paragraphs and lines, with associated utilities."""
+import os
+from typing import Any, Dict, List
+import zipfile
+
+from boltons.cacheutils import cachedproperty
+import defusedxml.ElementTree as ET
+from loguru import logger
+import toml
+
+from text_recognizer.datasets.util import _download_raw_dataset, DATA_DIRNAME
+
+RAW_DATA_DIRNAME = DATA_DIRNAME / "raw" / "iam"
+METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
+EXTRACTED_DATASET_DIRNAME = RAW_DATA_DIRNAME / "iamdb"
+RAW_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+
+DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be.
+LINE_REGION_PADDING = 0 # Add this many pixels around the exact coordinates.
+
+
+class IamDataset:
+ """IAM dataset.
+
+ "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text,
+ which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels."
+ From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database
+
+ The data split we will use is
+ IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text lines.
+ The validation set has been merged into the train set.
+ The train set has 7,101 lines from 326 writers.
+ The test set has 1,861 lines from 128 writers.
+ The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only.
+
+ """
+
+ def __init__(self) -> None:
+ self.metadata = toml.load(METADATA_FILENAME)
+
+ def load_or_generate_data(self) -> None:
+ """Downloads IAM dataset if xml files does not exist."""
+ if not self.xml_filenames:
+ self._download_iam()
+
+ @property
+ def xml_filenames(self) -> List:
+ """List of xml filenames."""
+ return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml"))
+
+ @property
+ def form_filenames(self) -> List:
+ """List of forms filenames."""
+ return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg"))
+
+ def _download_iam(self) -> None:
+ curdir = os.getcwd()
+ os.chdir(RAW_DATA_DIRNAME)
+ _download_raw_dataset(self.metadata)
+ _extract_raw_dataset(self.metadata)
+ os.chdir(curdir)
+
+ @property
+ def form_filenames_by_id(self) -> Dict:
+ """Creates a dictionary with filenames as keys and forms as values."""
+ return {filename.stem: filename for filename in self.form_filenames}
+
+ @cachedproperty
+ def line_strings_by_id(self) -> Dict:
+ """Return a dict from name of IAM form to a list of line texts in it."""
+ return {
+ filename.stem: _get_line_strings_from_xml_file(filename)
+ for filename in self.xml_filenames
+ }
+
+ @cachedproperty
+ def line_regions_by_id(self) -> Dict:
+ """Return a dict from name of IAM form to a list of (x1, x2, y1, y2) coordinates of all lines in it."""
+ return {
+ filename.stem: _get_line_regions_from_xml_file(filename)
+ for filename in self.xml_filenames
+ }
+
+ def __repr__(self) -> str:
+ """Print info about dataset."""
+ return "IAM Dataset\n" f"Number of forms: {len(self.xml_filenames)}\n"
+
+
+def _extract_raw_dataset(metadata: Dict) -> None:
+ logger.info("Extracting IAM data.")
+ with zipfile.ZipFile(metadata["filename"], "r") as zip_file:
+ zip_file.extractall()
+
+
+def _get_line_strings_from_xml_file(filename: str) -> List[str]:
+ """Get the text content of each line. Note that we replace &quot; with "."""
+ xml_root_element = ET.parse(filename).getroot() # nosec
+ xml_line_elements = xml_root_element.findall("handwritten-part/line")
+ return [el.attrib["text"].replace("&quot;", '"') for el in xml_line_elements]
+
+
+def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]:
+ """Get the line region dict for each line."""
+ xml_root_element = ET.parse(filename).getroot() # nosec
+ xml_line_elements = xml_root_element.findall("handwritten-part/line")
+ return [_get_line_region_from_xml_element(el) for el in xml_line_elements]
+
+
+def _get_line_region_from_xml_element(xml_line: Any) -> Dict[str, int]:
+ """Extracts coordinates for each line of text."""
+ # TODO: fix input!
+ word_elements = xml_line.findall("word/cmp")
+ x1s = [int(el.attrib["x"]) for el in word_elements]
+ y1s = [int(el.attrib["y"]) for el in word_elements]
+ x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements]
+ y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements]
+ return {
+ "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
+ "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
+ "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
+ "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
+ }
+
+
+def main() -> None:
+ """Initializes the dataset and print info about the dataset."""
+ dataset = IamDataset()
+ dataset.load_or_generate_data()
+ print(dataset)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/text_recognizer/datasets/iam_lines_dataset.py b/text_recognizer/datasets/iam_lines_dataset.py
new file mode 100644
index 0000000..1cb84bd
--- /dev/null
+++ b/text_recognizer/datasets/iam_lines_dataset.py
@@ -0,0 +1,110 @@
+"""IamLinesDataset class."""
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import h5py
+from loguru import logger
+import torch
+from torch import Tensor
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets.dataset import Dataset
+from text_recognizer.datasets.util import (
+ compute_sha256,
+ DATA_DIRNAME,
+ download_url,
+ EmnistMapper,
+)
+
+
+PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines"
+PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "iam_lines.h5"
+PROCESSED_DATA_URL = (
+ "https://s3-us-west-2.amazonaws.com/fsdl-public-assets/iam_lines.h5"
+)
+
+
+class IamLinesDataset(Dataset):
+ """IAM lines datasets for handwritten text lines."""
+
+ def __init__(
+ self,
+ train: bool = False,
+ subsample_fraction: float = None,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ init_token: Optional[str] = None,
+ pad_token: Optional[str] = None,
+ eos_token: Optional[str] = None,
+ lower: bool = False,
+ ) -> None:
+ self.pad_token = "_" if pad_token is None else pad_token
+
+ super().__init__(
+ train=train,
+ subsample_fraction=subsample_fraction,
+ transform=transform,
+ target_transform=target_transform,
+ init_token=init_token,
+ pad_token=pad_token,
+ eos_token=eos_token,
+ lower=lower,
+ )
+
+ @property
+ def input_shape(self) -> Tuple:
+ """Input shape of the data."""
+ return self.data.shape[1:] if self.data is not None else None
+
+ @property
+ def output_shape(self) -> Tuple:
+ """Output shape of the data."""
+ return (
+ self.targets.shape[1:] + (self.num_classes,)
+ if self.targets is not None
+ else None
+ )
+
+ def load_or_generate_data(self) -> None:
+ """Load or generate dataset data."""
+ if not PROCESSED_DATA_FILENAME.exists():
+ PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+ logger.info("Downloading IAM lines...")
+ download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME)
+ with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
+ self._data = f[f"x_{self.split}"][:]
+ self._targets = f[f"y_{self.split}"][:]
+ self._subsample()
+
+ def __repr__(self) -> str:
+ """Print info about the dataset."""
+ return (
+ "IAM Lines Dataset\n" # pylint: disable=no-member
+ f"Number classes: {self.num_classes}\n"
+ f"Mapping: {self.mapper.mapping}\n"
+ f"Data: {self.data.shape}\n"
+ f"Targets: {self.targets.shape}\n"
+ )
+
+ def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]:
+ """Fetches data, target pair of the dataset for a given and index or indices.
+
+ Args:
+ index (Union[int, Tensor]): Either a list or int of indices/index.
+
+ Returns:
+ Tuple[Tensor, Tensor]: Data target pair.
+
+ """
+ if torch.is_tensor(index):
+ index = index.tolist()
+
+ data = self.data[index]
+ targets = self.targets[index]
+
+ if self.transform:
+ data = self.transform(data)
+
+ if self.target_transform:
+ targets = self.target_transform(targets)
+
+ return data, targets
diff --git a/text_recognizer/datasets/iam_paragraphs_dataset.py b/text_recognizer/datasets/iam_paragraphs_dataset.py
new file mode 100644
index 0000000..8ba5142
--- /dev/null
+++ b/text_recognizer/datasets/iam_paragraphs_dataset.py
@@ -0,0 +1,291 @@
+"""IamParagraphsDataset class and functions for data processing."""
+import random
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import click
+import cv2
+import h5py
+from loguru import logger
+import numpy as np
+import torch
+from torch import Tensor
+from torchvision.transforms import ToTensor
+
+from text_recognizer import util
+from text_recognizer.datasets.dataset import Dataset
+from text_recognizer.datasets.iam_dataset import IamDataset
+from text_recognizer.datasets.util import (
+ compute_sha256,
+ DATA_DIRNAME,
+ download_url,
+ EmnistMapper,
+)
+
+INTERIM_DATA_DIRNAME = DATA_DIRNAME / "interim" / "iam_paragraphs"
+DEBUG_CROPS_DIRNAME = INTERIM_DATA_DIRNAME / "debug_crops"
+PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_paragraphs"
+CROPS_DIRNAME = PROCESSED_DATA_DIRNAME / "crops"
+GT_DIRNAME = PROCESSED_DATA_DIRNAME / "gt"
+
+PARAGRAPH_BUFFER = 50 # Pixels in the IAM form images to leave around the lines.
+TEST_FRACTION = 0.2
+SEED = 4711
+
+
+class IamParagraphsDataset(Dataset):
+ """IAM Paragraphs dataset for paragraphs of handwritten text."""
+
+ def __init__(
+ self,
+ train: bool = False,
+ subsample_fraction: float = None,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ ) -> None:
+ super().__init__(
+ train=train,
+ subsample_fraction=subsample_fraction,
+ transform=transform,
+ target_transform=target_transform,
+ )
+ # Load Iam dataset.
+ self.iam_dataset = IamDataset()
+
+ self.num_classes = 3
+ self._input_shape = (256, 256)
+ self._output_shape = self._input_shape + (self.num_classes,)
+ self._ids = None
+
+ def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]:
+ """Fetches data, target pair of the dataset for a given and index or indices.
+
+ Args:
+ index (Union[int, Tensor]): Either a list or int of indices/index.
+
+ Returns:
+ Tuple[Tensor, Tensor]: Data target pair.
+
+ """
+ if torch.is_tensor(index):
+ index = index.tolist()
+
+ data = self.data[index]
+ targets = self.targets[index]
+
+ seed = np.random.randint(SEED)
+ random.seed(seed) # apply this seed to target tranfsorms
+ torch.manual_seed(seed) # needed for torchvision 0.7
+ if self.transform:
+ data = self.transform(data)
+
+ random.seed(seed) # apply this seed to target tranfsorms
+ torch.manual_seed(seed) # needed for torchvision 0.7
+ if self.target_transform:
+ targets = self.target_transform(targets)
+
+ return data, targets.long()
+
+ @property
+ def ids(self) -> Tensor:
+ """Ids of the dataset."""
+ return self._ids
+
+ def get_data_and_target_from_id(self, id_: str) -> Tuple[Tensor, Tensor]:
+ """Get data target pair from id."""
+ ind = self.ids.index(id_)
+ return self.data[ind], self.targets[ind]
+
+ def load_or_generate_data(self) -> None:
+ """Load or generate dataset data."""
+ num_actual = len(list(CROPS_DIRNAME.glob("*.jpg")))
+ num_targets = len(self.iam_dataset.line_regions_by_id)
+
+ if num_actual < num_targets - 2:
+ self._process_iam_paragraphs()
+
+ self._data, self._targets, self._ids = _load_iam_paragraphs()
+ self._get_random_split()
+ self._subsample()
+
+ def _get_random_split(self) -> None:
+ np.random.seed(SEED)
+ num_train = int((1 - TEST_FRACTION) * self.data.shape[0])
+ indices = np.random.permutation(self.data.shape[0])
+ train_indices, test_indices = indices[:num_train], indices[num_train:]
+ if self.train:
+ self._data = self.data[train_indices]
+ self._targets = self.targets[train_indices]
+ else:
+ self._data = self.data[test_indices]
+ self._targets = self.targets[test_indices]
+
+ def _process_iam_paragraphs(self) -> None:
+ """Crop the part with the text.
+
+ For each page, crop out the part of it that correspond to the paragraph of text, and make sure all crops are
+ self.input_shape. The ground truth data is the same size, with a one-hot vector at each pixel
+ corresponding to labels 0=background, 1=odd-numbered line, 2=even-numbered line
+ """
+ crop_dims = self._decide_on_crop_dims()
+ CROPS_DIRNAME.mkdir(parents=True, exist_ok=True)
+ DEBUG_CROPS_DIRNAME.mkdir(parents=True, exist_ok=True)
+ GT_DIRNAME.mkdir(parents=True, exist_ok=True)
+ logger.info(
+ f"Cropping paragraphs, generating ground truth, and saving debugging images to {DEBUG_CROPS_DIRNAME}"
+ )
+ for filename in self.iam_dataset.form_filenames:
+ id_ = filename.stem
+ line_region = self.iam_dataset.line_regions_by_id[id_]
+ _crop_paragraph_image(filename, line_region, crop_dims, self.input_shape)
+
+ def _decide_on_crop_dims(self) -> Tuple[int, int]:
+ """Decide on the dimensions to crop out of the form image.
+
+ Since image width is larger than a comfortable crop around the longest paragraph,
+ we will make the crop a square form factor.
+ And since the found dimensions 610x610 are pretty close to 512x512,
+ we might as well resize crops and make it exactly that, which lets us
+ do all kinds of power-of-2 pooling and upsampling should we choose to.
+
+ Returns:
+ Tuple[int, int]: A tuple of crop dimensions.
+
+ Raises:
+ RuntimeError: When max crop height is larger than max crop width.
+
+ """
+
+ sample_form_filename = self.iam_dataset.form_filenames[0]
+ sample_image = util.read_image(sample_form_filename, grayscale=True)
+ max_crop_width = sample_image.shape[1]
+ max_crop_height = _get_max_paragraph_crop_height(
+ self.iam_dataset.line_regions_by_id
+ )
+ if not max_crop_height <= max_crop_width:
+ raise RuntimeError(
+ f"Max crop height is larger then max crop width: {max_crop_height} >= {max_crop_width}"
+ )
+
+ crop_dims = (max_crop_width, max_crop_width)
+ logger.info(
+ f"Max crop width and height were found to be {max_crop_width}x{max_crop_height}."
+ )
+ logger.info(f"Setting them to {max_crop_width}x{max_crop_width}")
+ return crop_dims
+
+ def __repr__(self) -> str:
+ """Return info about the dataset."""
+ return (
+ "IAM Paragraph Dataset\n" # pylint: disable=no-member
+ f"Num classes: {self.num_classes}\n"
+ f"Data: {self.data.shape}\n"
+ f"Targets: {self.targets.shape}\n"
+ )
+
+
+def _get_max_paragraph_crop_height(line_regions_by_id: Dict) -> int:
+ heights = []
+ for regions in line_regions_by_id.values():
+ min_y1 = min(region["y1"] for region in regions) - PARAGRAPH_BUFFER
+ max_y2 = max(region["y2"] for region in regions) + PARAGRAPH_BUFFER
+ height = max_y2 - min_y1
+ heights.append(height)
+ return max(heights)
+
+
+def _crop_paragraph_image(
+ filename: str, line_regions: Dict, crop_dims: Tuple[int, int], final_dims: Tuple
+) -> None:
+ image = util.read_image(filename, grayscale=True)
+
+ min_y1 = min(region["y1"] for region in line_regions) - PARAGRAPH_BUFFER
+ max_y2 = max(region["y2"] for region in line_regions) + PARAGRAPH_BUFFER
+ height = max_y2 - min_y1
+ crop_height = crop_dims[0]
+ buffer = (crop_height - height) // 2
+
+ # Generate image crop.
+ image_crop = 255 * np.ones(crop_dims, dtype=np.uint8)
+ try:
+ image_crop[buffer : buffer + height] = image[min_y1:max_y2]
+ except Exception as e: # pylint: disable=broad-except
+ logger.error(f"Rescued {filename}: {e}")
+ return
+
+ # Generate ground truth.
+ gt_image = np.zeros_like(image_crop, dtype=np.uint8)
+ for index, region in enumerate(line_regions):
+ gt_image[
+ (region["y1"] - min_y1 + buffer) : (region["y2"] - min_y1 + buffer),
+ region["x1"] : region["x2"],
+ ] = (index % 2 + 1)
+
+ # Generate image for debugging.
+ import matplotlib.pyplot as plt
+
+ cmap = plt.get_cmap("Set1")
+ image_crop_for_debug = np.dstack([image_crop, image_crop, image_crop])
+ for index, region in enumerate(line_regions):
+ color = [255 * _ for _ in cmap(index)[:-1]]
+ cv2.rectangle(
+ image_crop_for_debug,
+ (region["x1"], region["y1"] - min_y1 + buffer),
+ (region["x2"], region["y2"] - min_y1 + buffer),
+ color,
+ 3,
+ )
+ image_crop_for_debug = cv2.resize(
+ image_crop_for_debug, final_dims, interpolation=cv2.INTER_AREA
+ )
+ util.write_image(image_crop_for_debug, DEBUG_CROPS_DIRNAME / f"{filename.stem}.jpg")
+
+ image_crop = cv2.resize(image_crop, final_dims, interpolation=cv2.INTER_AREA)
+ util.write_image(image_crop, CROPS_DIRNAME / f"{filename.stem}.jpg")
+
+ gt_image = cv2.resize(gt_image, final_dims, interpolation=cv2.INTER_NEAREST)
+ util.write_image(gt_image, GT_DIRNAME / f"{filename.stem}.png")
+
+
+def _load_iam_paragraphs() -> None:
+ logger.info("Loading IAM paragraph crops and ground truth from image files...")
+ images = []
+ gt_images = []
+ ids = []
+ for filename in CROPS_DIRNAME.glob("*.jpg"):
+ id_ = filename.stem
+ image = util.read_image(filename, grayscale=True)
+ image = 1.0 - image / 255
+
+ gt_filename = GT_DIRNAME / f"{id_}.png"
+ gt_image = util.read_image(gt_filename, grayscale=True)
+
+ images.append(image)
+ gt_images.append(gt_image)
+ ids.append(id_)
+ images = np.array(images).astype(np.float32)
+ gt_images = np.array(gt_images).astype(np.uint8)
+ ids = np.array(ids)
+ return images, gt_images, ids
+
+
+@click.command()
+@click.option(
+ "--subsample_fraction",
+ type=float,
+ default=None,
+ help="The subsampling factor of the dataset.",
+)
+def main(subsample_fraction: float) -> None:
+ """Load dataset and print info."""
+ logger.info("Creating train set...")
+ dataset = IamParagraphsDataset(train=True, subsample_fraction=subsample_fraction)
+ dataset.load_or_generate_data()
+ print(dataset)
+ logger.info("Creating test set...")
+ dataset = IamParagraphsDataset(subsample_fraction=subsample_fraction)
+ dataset.load_or_generate_data()
+ print(dataset)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/text_recognizer/datasets/iam_preprocessor.py b/text_recognizer/datasets/iam_preprocessor.py
new file mode 100644
index 0000000..a93eb00
--- /dev/null
+++ b/text_recognizer/datasets/iam_preprocessor.py
@@ -0,0 +1,196 @@
+"""Preprocessor for extracting word letters from the IAM dataset.
+
+The code is mostly stolen from:
+
+ https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py
+
+"""
+
+import collections
+import itertools
+from pathlib import Path
+import re
+from typing import List, Optional, Union
+
+import click
+from loguru import logger
+import torch
+
+
+def load_metadata(
+ data_dir: Path, wordsep: str, use_words: bool = False
+) -> collections.defaultdict:
+ """Loads IAM metadata and returns it as a dictionary."""
+ forms = collections.defaultdict(list)
+ filename = "words.txt" if use_words else "lines.txt"
+
+ with open(data_dir / "ascii" / filename, "r") as f:
+ lines = (line.strip().split() for line in f if line[0] != "#")
+ for line in lines:
+ # Skip word segmentation errors.
+ if use_words and line[1] == "err":
+ continue
+ text = " ".join(line[8:])
+
+ # Remove garbage tokens:
+ text = text.replace("#", "")
+
+ # Swap word sep form | to wordsep
+ text = re.sub(r"\|+|\s", wordsep, text).strip(wordsep)
+ form_key = "-".join(line[0].split("-")[:2])
+ line_key = "-".join(line[0].split("-")[:3])
+ box_idx = 4 - use_words
+ box = tuple(int(val) for val in line[box_idx : box_idx + 4])
+ forms[form_key].append({"key": line_key, "box": box, "text": text})
+ return forms
+
+
+class Preprocessor:
+ """A preprocessor for the IAM dataset."""
+
+ # TODO: add lower case only to when generating...
+
+ def __init__(
+ self,
+ data_dir: Union[str, Path],
+ num_features: int,
+ tokens_path: Optional[Union[str, Path]] = None,
+ lexicon_path: Optional[Union[str, Path]] = None,
+ use_words: bool = False,
+ prepend_wordsep: bool = False,
+ ) -> None:
+ self.wordsep = "▁"
+ self._use_word = use_words
+ self._prepend_wordsep = prepend_wordsep
+
+ self.data_dir = Path(data_dir)
+
+ self.forms = load_metadata(self.data_dir, self.wordsep, use_words=use_words)
+
+ # Load the set of graphemes:
+ graphemes = set()
+ for _, form in self.forms.items():
+ for line in form:
+ graphemes.update(line["text"].lower())
+ self.graphemes = sorted(graphemes)
+
+ # Build the token-to-index and index-to-token maps.
+ if tokens_path is not None:
+ with open(tokens_path, "r") as f:
+ self.tokens = [line.strip() for line in f]
+ else:
+ self.tokens = self.graphemes
+
+ if lexicon_path is not None:
+ with open(lexicon_path, "r") as f:
+ lexicon = (line.strip().split() for line in f)
+ lexicon = {line[0]: line[1:] for line in lexicon}
+ self.lexicon = lexicon
+ else:
+ self.lexicon = None
+
+ self.graphemes_to_index = {t: i for i, t in enumerate(self.graphemes)}
+ self.tokens_to_index = {t: i for i, t in enumerate(self.tokens)}
+ self.num_features = num_features
+ self.text = []
+
+ @property
+ def num_tokens(self) -> int:
+ """Returns the number or tokens."""
+ return len(self.tokens)
+
+ @property
+ def use_words(self) -> bool:
+ """If words are used."""
+ return self._use_word
+
+ def extract_train_text(self) -> None:
+ """Extracts training text."""
+ keys = []
+ with open(self.data_dir / "task" / "trainset.txt") as f:
+ keys.extend((line.strip() for line in f))
+
+ for _, examples in self.forms.items():
+ for example in examples:
+ if example["key"] not in keys:
+ continue
+ self.text.append(example["text"].lower())
+
+ def to_index(self, line: str) -> torch.LongTensor:
+ """Converts text to a tensor of indices."""
+ token_to_index = self.graphemes_to_index
+ if self.lexicon is not None:
+ if len(line) > 0:
+ # If the word is not found in the lexicon, fall back to letters.
+ line = [
+ t
+ for w in line.split(self.wordsep)
+ for t in self.lexicon.get(w, self.wordsep + w)
+ ]
+ token_to_index = self.tokens_to_index
+ if self._prepend_wordsep:
+ line = itertools.chain([self.wordsep], line)
+ return torch.LongTensor([token_to_index[t] for t in line])
+
+ def to_text(self, indices: List[int]) -> str:
+ """Converts indices to text."""
+ # Roughly the inverse of `to_index`
+ encoding = self.graphemes
+ if self.lexicon is not None:
+ encoding = self.tokens
+ return self._post_process(encoding[i] for i in indices)
+
+ def tokens_to_text(self, indices: List[int]) -> str:
+ """Converts tokens to text."""
+ return self._post_process(self.tokens[i] for i in indices)
+
+ def _post_process(self, indices: List[int]) -> str:
+ """A list join."""
+ return "".join(indices).strip(self.wordsep)
+
+
+@click.command()
+@click.option("--data_dir", type=str, default=None, help="Path to iam dataset")
+@click.option(
+ "--use_words", is_flag=True, help="Load word segmented dataset instead of lines"
+)
+@click.option(
+ "--save_text", type=str, default=None, help="Path to save parsed train text"
+)
+@click.option("--save_tokens", type=str, default=None, help="Path to save tokens")
+def cli(
+ data_dir: Optional[str],
+ use_words: bool,
+ save_text: Optional[str],
+ save_tokens: Optional[str],
+) -> None:
+ """CLI for extracting text data from the iam dataset."""
+ if data_dir is None:
+ data_dir = (
+ Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb"
+ )
+ logger.debug(f"Using data dir: {data_dir}")
+ if not data_dir.exists():
+ raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+ else:
+ data_dir = Path(data_dir)
+
+ preprocessor = Preprocessor(data_dir, 64, use_words=use_words)
+ preprocessor.extract_train_text()
+
+ processed_dir = data_dir.parents[2] / "processed" / "iam_lines"
+ logger.debug(f"Saving processed files at: {processed_dir}")
+
+ if save_text is not None:
+ logger.info("Saving training text")
+ with open(processed_dir / save_text, "w") as f:
+ f.write("\n".join(t for t in preprocessor.text))
+
+ if save_tokens is not None:
+ logger.info("Saving tokens")
+ with open(processed_dir / save_tokens, "w") as f:
+ f.write("\n".join(preprocessor.tokens))
+
+
+if __name__ == "__main__":
+ cli()
diff --git a/text_recognizer/datasets/sentence_generator.py b/text_recognizer/datasets/sentence_generator.py
new file mode 100644
index 0000000..dd76652
--- /dev/null
+++ b/text_recognizer/datasets/sentence_generator.py
@@ -0,0 +1,81 @@
+"""Downloading the Brown corpus with NLTK for sentence generating."""
+
+import itertools
+import re
+import string
+from typing import Optional
+
+import nltk
+from nltk.corpus.reader.util import ConcatenatedCorpusView
+import numpy as np
+
+from text_recognizer.datasets.util import DATA_DIRNAME
+
+NLTK_DATA_DIRNAME = DATA_DIRNAME / "raw" / "nltk"
+
+
+class SentenceGenerator:
+ """Generates text sentences using the Brown corpus."""
+
+ def __init__(self, max_length: Optional[int] = None) -> None:
+ """Loads the corpus and sets word start indices."""
+ self.corpus = brown_corpus()
+ self.word_start_indices = [0] + [
+ _.start(0) + 1 for _ in re.finditer(" ", self.corpus)
+ ]
+ self.max_length = max_length
+
+ def generate(self, max_length: Optional[int] = None) -> str:
+ """Generates a word or sentences from the Brown corpus.
+
+ Sample a string from the Brown corpus of length at least one word and at most max_length, padding to
+ max_length with the '_' characters if sentence is shorter.
+
+ Args:
+ max_length (Optional[int]): The maximum number of characters in the sentence. Defaults to None.
+
+ Returns:
+ str: A sentence from the Brown corpus.
+
+ Raises:
+ ValueError: If max_length was not specified at initialization and not given as an argument.
+
+ """
+ if max_length is None:
+ max_length = self.max_length
+ if max_length is None:
+ raise ValueError(
+ "Must provide max_length to this method or when making this object."
+ )
+
+ index = np.random.randint(0, len(self.word_start_indices) - 1)
+ start_index = self.word_start_indices[index]
+ end_index_candidates = []
+ for index in range(index + 1, len(self.word_start_indices)):
+ if self.word_start_indices[index] - start_index > max_length:
+ break
+ end_index_candidates.append(self.word_start_indices[index])
+ end_index = np.random.choice(end_index_candidates)
+ sampled_text = self.corpus[start_index:end_index].strip()
+ padding = "_" * (max_length - len(sampled_text))
+ return sampled_text + padding
+
+
+def brown_corpus() -> str:
+ """Returns a single string with the Brown corpus with all punctuations stripped."""
+ sentences = load_nltk_brown_corpus()
+ corpus = " ".join(itertools.chain.from_iterable(sentences))
+ corpus = corpus.translate({ord(c): None for c in string.punctuation})
+ corpus = re.sub(" +", " ", corpus)
+ return corpus
+
+
+def load_nltk_brown_corpus() -> ConcatenatedCorpusView:
+ """Load the Brown corpus using the NLTK library."""
+ nltk.data.path.append(NLTK_DATA_DIRNAME)
+ try:
+ nltk.corpus.brown.sents()
+ except LookupError:
+ NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+ nltk.download("brown", download_dir=NLTK_DATA_DIRNAME)
+ return nltk.corpus.brown.sents()
diff --git a/text_recognizer/datasets/transforms.py b/text_recognizer/datasets/transforms.py
new file mode 100644
index 0000000..b6a48f5
--- /dev/null
+++ b/text_recognizer/datasets/transforms.py
@@ -0,0 +1,266 @@
+"""Transforms for PyTorch datasets."""
+from abc import abstractmethod
+from pathlib import Path
+import random
+from typing import Any, Optional, Union
+
+from loguru import logger
+import numpy as np
+from PIL import Image
+import torch
+from torch import Tensor
+import torch.nn.functional as F
+from torchvision import transforms
+from torchvision.transforms import (
+ ColorJitter,
+ Compose,
+ Normalize,
+ RandomAffine,
+ RandomHorizontalFlip,
+ RandomRotation,
+ ToPILImage,
+ ToTensor,
+)
+
+from text_recognizer.datasets.iam_preprocessor import Preprocessor
+from text_recognizer.datasets.util import EmnistMapper
+
+
+class RandomResizeCrop:
+ """Image transform with random resize and crop applied.
+
+ Stolen from
+
+ https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py
+
+ """
+
+ def __init__(self, jitter: int = 10, ratio: float = 0.5) -> None:
+ self.jitter = jitter
+ self.ratio = ratio
+
+ def __call__(self, img: np.ndarray) -> np.ndarray:
+ """Applies random crop and rotation to an image."""
+ w, h = img.size
+
+ # pad with white:
+ img = transforms.functional.pad(img, self.jitter, fill=255)
+
+ # crop at random (x, y):
+ x = self.jitter + random.randint(-self.jitter, self.jitter)
+ y = self.jitter + random.randint(-self.jitter, self.jitter)
+
+ # randomize aspect ratio:
+ size_w = w * random.uniform(1 - self.ratio, 1 + self.ratio)
+ size = (h, int(size_w))
+ img = transforms.functional.resized_crop(img, y, x, h, w, size)
+ return img
+
+
+class Transpose:
+ """Transposes the EMNIST image to the correct orientation."""
+
+ def __call__(self, image: Image) -> np.ndarray:
+ """Swaps axis."""
+ return np.array(image).swapaxes(0, 1)
+
+
+class Resize:
+ """Resizes a tensor to a specified width."""
+
+ def __init__(self, width: int = 952) -> None:
+ # The default is 952 because of the IAM dataset.
+ self.width = width
+
+ def __call__(self, image: Tensor) -> Tensor:
+ """Resize tensor in the last dimension."""
+ return F.interpolate(image, size=self.width, mode="nearest")
+
+
+class AddTokens:
+ """Adds start of sequence and end of sequence tokens to target tensor."""
+
+ def __init__(self, pad_token: str, eos_token: str, init_token: str = None) -> None:
+ self.init_token = init_token
+ self.pad_token = pad_token
+ self.eos_token = eos_token
+ if self.init_token is not None:
+ self.emnist_mapper = EmnistMapper(
+ init_token=self.init_token,
+ pad_token=self.pad_token,
+ eos_token=self.eos_token,
+ )
+ else:
+ self.emnist_mapper = EmnistMapper(
+ pad_token=self.pad_token, eos_token=self.eos_token,
+ )
+ self.pad_value = self.emnist_mapper(self.pad_token)
+ self.eos_value = self.emnist_mapper(self.eos_token)
+
+ def __call__(self, target: Tensor) -> Tensor:
+ """Adds a sos token to the begining and a eos token to the end of a target sequence."""
+ dtype, device = target.dtype, target.device
+
+ # Find the where padding starts.
+ pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item()
+
+ target[pad_index] = self.eos_value
+
+ if self.init_token is not None:
+ self.sos_value = self.emnist_mapper(self.init_token)
+ sos = torch.tensor([self.sos_value], dtype=dtype, device=device)
+ target = torch.cat([sos, target], dim=0)
+
+ return target
+
+
+class ApplyContrast:
+ """Sets everything below a threshold to zero, i.e. increase contrast."""
+
+ def __init__(self, low: float = 0.0, high: float = 0.25) -> None:
+ self.low = low
+ self.high = high
+
+ def __call__(self, x: Tensor) -> Tensor:
+ """Apply mask binary mask to input tensor."""
+ mask = x > np.random.RandomState().uniform(low=self.low, high=self.high)
+ return x * mask
+
+
+class Unsqueeze:
+ """Add a dimension to the tensor."""
+
+ def __call__(self, x: Tensor) -> Tensor:
+ """Adds dim."""
+ return x.unsqueeze(0)
+
+
+class Squeeze:
+ """Removes the first dimension of a tensor."""
+
+ def __call__(self, x: Tensor) -> Tensor:
+ """Removes first dim."""
+ return x.squeeze(0)
+
+
+class ToLower:
+ """Converts target to lower case."""
+
+ def __call__(self, target: Tensor) -> Tensor:
+ """Corrects index value in target tensor."""
+ device = target.device
+ return torch.stack([x - 26 if x > 35 else x for x in target]).to(device)
+
+
+class ToCharcters:
+ """Converts integers to characters."""
+
+ def __init__(
+ self, pad_token: str, eos_token: str, init_token: str = None, lower: bool = True
+ ) -> None:
+ self.init_token = init_token
+ self.pad_token = pad_token
+ self.eos_token = eos_token
+ if self.init_token is not None:
+ self.emnist_mapper = EmnistMapper(
+ init_token=self.init_token,
+ pad_token=self.pad_token,
+ eos_token=self.eos_token,
+ lower=lower,
+ )
+ else:
+ self.emnist_mapper = EmnistMapper(
+ pad_token=self.pad_token, eos_token=self.eos_token, lower=lower
+ )
+
+ def __call__(self, y: Tensor) -> str:
+ """Converts a Tensor to a str."""
+ return (
+ "".join([self.emnist_mapper(int(i)) for i in y])
+ .strip("_")
+ .replace(" ", "▁")
+ )
+
+
+class WordPieces:
+ """Abstract transform for word pieces."""
+
+ def __init__(
+ self,
+ num_features: int,
+ data_dir: Optional[Union[str, Path]] = None,
+ tokens: Optional[Union[str, Path]] = None,
+ lexicon: Optional[Union[str, Path]] = None,
+ use_words: bool = False,
+ prepend_wordsep: bool = False,
+ ) -> None:
+ if data_dir is None:
+ data_dir = (
+ Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb"
+ )
+ logger.debug(f"Using data dir: {data_dir}")
+ if not data_dir.exists():
+ raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+ else:
+ data_dir = Path(data_dir)
+ processed_path = (
+ Path(__file__).resolve().parents[3] / "data" / "processed" / "iam_lines"
+ )
+ tokens_path = processed_path / tokens
+ lexicon_path = processed_path / lexicon
+
+ self.preprocessor = Preprocessor(
+ data_dir,
+ num_features,
+ tokens_path,
+ lexicon_path,
+ use_words,
+ prepend_wordsep,
+ )
+
+ @abstractmethod
+ def __call__(self, *args, **kwargs) -> Any:
+ """Transforms input."""
+ ...
+
+
+class ToWordPieces(WordPieces):
+ """Transforms str to word pieces."""
+
+ def __init__(
+ self,
+ num_features: int,
+ data_dir: Optional[Union[str, Path]] = None,
+ tokens: Optional[Union[str, Path]] = None,
+ lexicon: Optional[Union[str, Path]] = None,
+ use_words: bool = False,
+ prepend_wordsep: bool = False,
+ ) -> None:
+ super().__init__(
+ num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep
+ )
+
+ def __call__(self, line: str) -> Tensor:
+ """Transforms str to word pieces."""
+ return self.preprocessor.to_index(line)
+
+
+class ToText(WordPieces):
+ """Takes word pieces and converts them to text."""
+
+ def __init__(
+ self,
+ num_features: int,
+ data_dir: Optional[Union[str, Path]] = None,
+ tokens: Optional[Union[str, Path]] = None,
+ lexicon: Optional[Union[str, Path]] = None,
+ use_words: bool = False,
+ prepend_wordsep: bool = False,
+ ) -> None:
+ super().__init__(
+ num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep
+ )
+
+ def __call__(self, x: Tensor) -> str:
+ """Converts tensor to text."""
+ return self.preprocessor.to_text(x.tolist())
diff --git a/text_recognizer/datasets/util.py b/text_recognizer/datasets/util.py
new file mode 100644
index 0000000..da87756
--- /dev/null
+++ b/text_recognizer/datasets/util.py
@@ -0,0 +1,209 @@
+"""Util functions for datasets."""
+import hashlib
+import json
+import os
+from pathlib import Path
+import string
+from typing import Dict, List, Optional, Union
+from urllib.request import urlretrieve
+
+from loguru import logger
+import numpy as np
+import torch
+from torch import Tensor
+from torchvision.datasets import EMNIST
+from tqdm import tqdm
+
+DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data"
+ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json"
+
+
+def save_emnist_essentials(emnsit_dataset: EMNIST = EMNIST) -> None:
+ """Extract and saves EMNIST essentials."""
+ labels = emnsit_dataset.classes
+ labels.sort()
+ mapping = [(i, str(label)) for i, label in enumerate(labels)]
+ essentials = {
+ "mapping": mapping,
+ "input_shape": tuple(np.array(emnsit_dataset[0][0]).shape[:]),
+ }
+ logger.info("Saving emnist essentials...")
+ with open(ESSENTIALS_FILENAME, "w") as f:
+ json.dump(essentials, f)
+
+
+def download_emnist() -> None:
+ """Download the EMNIST dataset via the PyTorch class."""
+ logger.info(f"Data directory is: {DATA_DIRNAME}")
+ dataset = EMNIST(root=DATA_DIRNAME, split="byclass", download=True)
+ save_emnist_essentials(dataset)
+
+
+class EmnistMapper:
+ """Mapper between network output to Emnist character."""
+
+ def __init__(
+ self,
+ pad_token: str,
+ init_token: Optional[str] = None,
+ eos_token: Optional[str] = None,
+ lower: bool = False,
+ ) -> None:
+ """Loads the emnist essentials file with the mapping and input shape."""
+ self.init_token = init_token
+ self.pad_token = pad_token
+ self.eos_token = eos_token
+ self.lower = lower
+
+ self.essentials = self._load_emnist_essentials()
+ # Load dataset information.
+ self._mapping = dict(self.essentials["mapping"])
+ self._augment_emnist_mapping()
+ self._inverse_mapping = {v: k for k, v in self.mapping.items()}
+ self._num_classes = len(self.mapping)
+ self._input_shape = self.essentials["input_shape"]
+
+ def __call__(self, token: Union[str, int, np.uint8, Tensor]) -> Union[str, int]:
+ """Maps the token to emnist character or character index.
+
+ If the token is an integer (index), the method will return the Emnist character corresponding to that index.
+ If the token is a str (Emnist character), the method will return the corresponding index for that character.
+
+ Args:
+ token (Union[str, int, np.uint8, Tensor]): Either a string or index (integer).
+
+ Returns:
+ Union[str, int]: The mapping result.
+
+ Raises:
+ KeyError: If the index or string does not exist in the mapping.
+
+ """
+ if (
+ (isinstance(token, np.uint8) or isinstance(token, int))
+ or torch.is_tensor(token)
+ and int(token) in self.mapping
+ ):
+ return self.mapping[int(token)]
+ elif isinstance(token, str) and token in self._inverse_mapping:
+ return self._inverse_mapping[token]
+ else:
+ raise KeyError(f"Token {token} does not exist in the mappings.")
+
+ @property
+ def mapping(self) -> Dict:
+ """Returns the mapping between index and character."""
+ return self._mapping
+
+ @property
+ def inverse_mapping(self) -> Dict:
+ """Returns the mapping between character and index."""
+ return self._inverse_mapping
+
+ @property
+ def num_classes(self) -> int:
+ """Returns the number of classes in the dataset."""
+ return self._num_classes
+
+ @property
+ def input_shape(self) -> List[int]:
+ """Returns the input shape of the Emnist characters."""
+ return self._input_shape
+
+ def _load_emnist_essentials(self) -> Dict:
+ """Load the EMNIST mapping."""
+ with open(str(ESSENTIALS_FILENAME)) as f:
+ essentials = json.load(f)
+ return essentials
+
+ def _augment_emnist_mapping(self) -> None:
+ """Augment the mapping with extra symbols."""
+ # Extra symbols in IAM dataset
+ if self.lower:
+ self._mapping = {
+ k: str(v)
+ for k, v in enumerate(list(range(10)) + list(string.ascii_lowercase))
+ }
+
+ extra_symbols = [
+ " ",
+ "!",
+ '"',
+ "#",
+ "&",
+ "'",
+ "(",
+ ")",
+ "*",
+ "+",
+ ",",
+ "-",
+ ".",
+ "/",
+ ":",
+ ";",
+ "?",
+ ]
+
+ # padding symbol, and acts as blank symbol as well.
+ extra_symbols.append(self.pad_token)
+
+ if self.init_token is not None:
+ extra_symbols.append(self.init_token)
+
+ if self.eos_token is not None:
+ extra_symbols.append(self.eos_token)
+
+ max_key = max(self.mapping.keys())
+ extra_mapping = {}
+ for i, symbol in enumerate(extra_symbols):
+ extra_mapping[max_key + 1 + i] = symbol
+
+ self._mapping = {**self.mapping, **extra_mapping}
+
+
+def compute_sha256(filename: Union[Path, str]) -> str:
+ """Returns the SHA256 checksum of a file."""
+ with open(filename, "rb") as f:
+ return hashlib.sha256(f.read()).hexdigest()
+
+
+class TqdmUpTo(tqdm):
+ """TQDM progress bar when downloading files.
+
+ From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
+
+ """
+
+ def update_to(
+ self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None
+ ) -> None:
+ """Updates the progress bar.
+
+ Args:
+ blocks (int): Number of blocks transferred so far. Defaults to 1.
+ block_size (int): Size of each block, in tqdm units. Defaults to 1.
+ total_size (Optional[int]): Total size in tqdm units. Defaults to None.
+ """
+ if total_size is not None:
+ self.total = total_size # pylint: disable=attribute-defined-outside-init
+ self.update(blocks * block_size - self.n)
+
+
+def download_url(url: str, filename: str) -> None:
+ """Downloads a file from url to filename, with a progress bar."""
+ with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
+ urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec
+
+
+def _download_raw_dataset(metadata: Dict) -> None:
+ if os.path.exists(metadata["filename"]):
+ return
+ logger.info(f"Downloading raw dataset from {metadata['url']}...")
+ download_url(metadata["url"], metadata["filename"])
+ logger.info("Computing SHA-256...")
+ sha256 = compute_sha256(metadata["filename"])
+ if sha256 != metadata["sha256"]:
+ raise ValueError(
+ "Downloaded data file SHA-256 does not match that listed in metadata document."
+ )
diff --git a/text_recognizer/line_predictor.py b/text_recognizer/line_predictor.py
new file mode 100644
index 0000000..8e348fe
--- /dev/null
+++ b/text_recognizer/line_predictor.py
@@ -0,0 +1,28 @@
+"""LinePredictor class."""
+import importlib
+from typing import Tuple, Union
+
+import numpy as np
+from torch import nn
+
+from text_recognizer import datasets, networks
+from text_recognizer.models import TransformerModel
+from text_recognizer.util import read_image
+
+
+class LinePredictor:
+ """Given an image of a line of handwritten text, recognizes the text content."""
+
+ def __init__(self, dataset: str, network_fn: str) -> None:
+ network_fn = getattr(networks, network_fn)
+ dataset = getattr(datasets, dataset)
+ self.model = TransformerModel(network_fn=network_fn, dataset=dataset)
+ self.model.eval()
+
+ def predict(self, image_or_filename: Union[np.ndarray, str]) -> Tuple[str, float]:
+ """Predict on a single images contianing a handwritten character."""
+ if isinstance(image_or_filename, str):
+ image = read_image(image_or_filename, grayscale=True)
+ else:
+ image = image_or_filename
+ return self.model.predict_on_image(image)
diff --git a/text_recognizer/models/__init__.py b/text_recognizer/models/__init__.py
new file mode 100644
index 0000000..7647d7e
--- /dev/null
+++ b/text_recognizer/models/__init__.py
@@ -0,0 +1,18 @@
+"""Model modules."""
+from .base import Model
+from .character_model import CharacterModel
+from .crnn_model import CRNNModel
+from .ctc_transformer_model import CTCTransformerModel
+from .segmentation_model import SegmentationModel
+from .transformer_model import TransformerModel
+from .vqvae_model import VQVAEModel
+
+__all__ = [
+ "CharacterModel",
+ "CRNNModel",
+ "CTCTransformerModel",
+ "Model",
+ "SegmentationModel",
+ "TransformerModel",
+ "VQVAEModel",
+]
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
new file mode 100644
index 0000000..70f4cdb
--- /dev/null
+++ b/text_recognizer/models/base.py
@@ -0,0 +1,455 @@
+"""Abstract Model class for PyTorch neural networks."""
+
+from abc import ABC, abstractmethod
+from glob import glob
+import importlib
+from pathlib import Path
+import re
+import shutil
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+
+from loguru import logger
+import torch
+from torch import nn
+from torch import Tensor
+from torch.optim.swa_utils import AveragedModel, SWALR
+from torch.utils.data import DataLoader, Dataset, random_split
+from torchsummary import summary
+
+from text_recognizer import datasets
+from text_recognizer import networks
+from text_recognizer.datasets import EmnistMapper
+
+WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights"
+
+
+class Model(ABC):
+ """Abstract Model class with composition of different parts defining a PyTorch neural network."""
+
+ def __init__(
+ self,
+ network_fn: str,
+ dataset: str,
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ """Base class, to be inherited by model for specific type of data.
+
+ Args:
+ network_fn (str): The name of network.
+ dataset (str): The name dataset class.
+ network_args (Optional[Dict]): Arguments for the network. Defaults to None.
+ dataset_args (Optional[Dict]): Arguments for the dataset.
+ metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
+ criterion (Optional[Callable]): The criterion to evaluate the performance of the network.
+ Defaults to None.
+ criterion_args (Optional[Dict]): Dict of arguments for criterion. Defaults to None.
+ optimizer (Optional[Callable]): The optimizer for updating the weights. Defaults to None.
+ optimizer_args (Optional[Dict]): Dict of arguments for optimizer. Defaults to None.
+ lr_scheduler (Optional[Callable]): A PyTorch learning rate scheduler. Defaults to None.
+ lr_scheduler_args (Optional[Dict]): Dict of arguments for learning rate scheduler. Defaults to
+ None.
+ swa_args (Optional[Dict]): Dict of arguments for stochastic weight averaging. Defaults to
+ None.
+ device (Optional[str]): Name of the device to train on. Defaults to None.
+
+ """
+ self._name = f"{self.__class__.__name__}_{dataset}_{network_fn}"
+ # Has to be set in subclass.
+ self._mapper = None
+
+ # Placeholder.
+ self._input_shape = None
+
+ self.dataset_name = dataset
+ self.dataset = None
+ self.dataset_args = dataset_args
+
+ # Placeholders for datasets.
+ self.train_dataset = None
+ self.val_dataset = None
+ self.test_dataset = None
+
+ # Stochastic Weight Averaging placeholders.
+ self.swa_args = swa_args
+ self._swa_scheduler = None
+ self._swa_network = None
+ self._use_swa_model = False
+
+ # Experiment directory.
+ self.model_dir = None
+
+ # Flag for configured model.
+ self.is_configured = False
+ self.data_prepared = False
+
+ # Flag for stopping training.
+ self.stop_training = False
+
+ self._metrics = metrics if metrics is not None else None
+
+ # Set the device.
+ self._device = (
+ torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ if device is None
+ else device
+ )
+
+ # Configure network.
+ self._network = None
+ self._network_args = network_args
+ self._configure_network(network_fn)
+
+ # Place network on device (GPU).
+ self.to_device()
+
+ # Loss and Optimizer placeholders for before loading.
+ self._criterion = criterion
+ self.criterion_args = criterion_args
+
+ self._optimizer = optimizer
+ self.optimizer_args = optimizer_args
+
+ self._lr_scheduler = lr_scheduler
+ self.lr_scheduler_args = lr_scheduler_args
+
+ def configure_model(self) -> None:
+ """Configures criterion and optimizers."""
+ if not self.is_configured:
+ self._configure_criterion()
+ self._configure_optimizers()
+
+ # Set this flag to true to prevent the model from configuring again.
+ self.is_configured = True
+
+ def prepare_data(self) -> None:
+ """Prepare data for training."""
+ # TODO add downloading.
+ if not self.data_prepared:
+ # Load dataset module.
+ self.dataset = getattr(datasets, self.dataset_name)
+
+ # Load train dataset.
+ train_dataset = self.dataset(train=True, **self.dataset_args["args"])
+ train_dataset.load_or_generate_data()
+
+ # Set input shape.
+ self._input_shape = train_dataset.input_shape
+
+ # Split train dataset into a training and validation partition.
+ dataset_len = len(train_dataset)
+ train_len = int(
+ self.dataset_args["train_args"]["train_fraction"] * dataset_len
+ )
+ val_len = dataset_len - train_len
+ self.train_dataset, self.val_dataset = random_split(
+ train_dataset, lengths=[train_len, val_len]
+ )
+
+ # Load test dataset.
+ self.test_dataset = self.dataset(train=False, **self.dataset_args["args"])
+ self.test_dataset.load_or_generate_data()
+
+ # Set the flag to true to disable ability to load data again.
+ self.data_prepared = True
+
+ def train_dataloader(self) -> DataLoader:
+ """Returns data loader for training set."""
+ return DataLoader(
+ self.train_dataset,
+ batch_size=self.dataset_args["train_args"]["batch_size"],
+ num_workers=self.dataset_args["train_args"]["num_workers"],
+ shuffle=True,
+ pin_memory=True,
+ )
+
+ def val_dataloader(self) -> DataLoader:
+ """Returns data loader for validation set."""
+ return DataLoader(
+ self.val_dataset,
+ batch_size=self.dataset_args["train_args"]["batch_size"],
+ num_workers=self.dataset_args["train_args"]["num_workers"],
+ shuffle=True,
+ pin_memory=True,
+ )
+
+ def test_dataloader(self) -> DataLoader:
+ """Returns data loader for test set."""
+ return DataLoader(
+ self.test_dataset,
+ batch_size=self.dataset_args["train_args"]["batch_size"],
+ num_workers=self.dataset_args["train_args"]["num_workers"],
+ shuffle=False,
+ pin_memory=True,
+ )
+
+ def _configure_network(self, network_fn: Type[nn.Module]) -> None:
+ """Loads the network."""
+ # If no network arguments are given, load pretrained weights if they exist.
+ # Load network module.
+ network_fn = getattr(networks, network_fn)
+ if self._network_args is None:
+ self.load_weights(network_fn)
+ else:
+ self._network = network_fn(**self._network_args)
+
+ def _configure_criterion(self) -> None:
+ """Loads the criterion."""
+ self._criterion = (
+ self._criterion(**self.criterion_args)
+ if self._criterion is not None
+ else None
+ )
+
+ def _configure_optimizers(self,) -> None:
+ """Loads the optimizers."""
+ if self._optimizer is not None:
+ self._optimizer = self._optimizer(
+ self._network.parameters(), **self.optimizer_args
+ )
+ else:
+ self._optimizer = None
+
+ if self._optimizer and self._lr_scheduler is not None:
+ if "steps_per_epoch" in self.lr_scheduler_args:
+ self.lr_scheduler_args["steps_per_epoch"] = len(self.train_dataloader())
+
+ # Assume lr scheduler should update at each epoch if not specified.
+ if "interval" not in self.lr_scheduler_args:
+ interval = "epoch"
+ else:
+ interval = self.lr_scheduler_args.pop("interval")
+ self._lr_scheduler = {
+ "lr_scheduler": self._lr_scheduler(
+ self._optimizer, **self.lr_scheduler_args
+ ),
+ "interval": interval,
+ }
+
+ if self.swa_args is not None:
+ self._swa_scheduler = {
+ "swa_scheduler": SWALR(self._optimizer, swa_lr=self.swa_args["lr"]),
+ "swa_start": self.swa_args["start"],
+ }
+ self._swa_network = AveragedModel(self._network).to(self.device)
+
+ @property
+ def name(self) -> str:
+ """Returns the name of the model."""
+ return self._name
+
+ @property
+ def input_shape(self) -> Tuple[int, ...]:
+ """The input shape."""
+ return self._input_shape
+
+ @property
+ def mapper(self) -> EmnistMapper:
+ """Returns the mapper that maps between ints and chars."""
+ return self._mapper
+
+ @property
+ def mapping(self) -> Dict:
+ """Returns the mapping between network output and Emnist character."""
+ return self._mapper.mapping if self._mapper is not None else None
+
+ def eval(self) -> None:
+ """Sets the network to evaluation mode."""
+ self._network.eval()
+
+ def train(self) -> None:
+ """Sets the network to train mode."""
+ self._network.train()
+
+ @property
+ def device(self) -> str:
+ """Device where the weights are stored, i.e. cpu or cuda."""
+ return self._device
+
+ @property
+ def metrics(self) -> Optional[Dict]:
+ """Metrics."""
+ return self._metrics
+
+ @property
+ def criterion(self) -> Optional[Callable]:
+ """Criterion."""
+ return self._criterion
+
+ @property
+ def optimizer(self) -> Optional[Callable]:
+ """Optimizer."""
+ return self._optimizer
+
+ @property
+ def lr_scheduler(self) -> Optional[Dict]:
+ """Returns a directory with the learning rate scheduler."""
+ return self._lr_scheduler
+
+ @property
+ def swa_scheduler(self) -> Optional[Dict]:
+ """Returns a directory with the stochastic weight averaging scheduler."""
+ return self._swa_scheduler
+
+ @property
+ def swa_network(self) -> Optional[Callable]:
+ """Returns the stochastic weight averaging network."""
+ return self._swa_network
+
+ @property
+ def network(self) -> Type[nn.Module]:
+ """Neural network."""
+ # Returns the SWA network if available.
+ return self._network
+
+ @property
+ def weights_filename(self) -> str:
+ """Filepath to the network weights."""
+ WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True)
+ return str(WEIGHT_DIRNAME / f"{self._name}_weights.pt")
+
+ def use_swa_model(self) -> None:
+ """Set to use predictions from SWA model."""
+ if self.swa_network is not None:
+ self._use_swa_model = True
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Feedforward pass with the network."""
+ if self._use_swa_model:
+ return self.swa_network(x)
+ else:
+ return self.network(x)
+
+ def summary(
+ self,
+ input_shape: Optional[Union[List, Tuple]] = None,
+ depth: int = 3,
+ device: Optional[str] = None,
+ ) -> None:
+ """Prints a summary of the network architecture."""
+ device = self.device if device is None else device
+
+ if input_shape is not None:
+ summary(self.network, input_shape, depth=depth, device=device)
+ elif self._input_shape is not None:
+ input_shape = tuple(self._input_shape)
+ summary(self.network, input_shape, depth=depth, device=device)
+ else:
+ logger.warning("Could not print summary as input shape is not set.")
+
+ def to_device(self) -> None:
+ """Places the network on the device (GPU)."""
+ self._network.to(self._device)
+
+ def _get_state_dict(self) -> Dict:
+ """Get the state dict of the model."""
+ state = {"model_state": self._network.state_dict()}
+
+ if self._optimizer is not None:
+ state["optimizer_state"] = self._optimizer.state_dict()
+
+ if self._lr_scheduler is not None:
+ state["scheduler_state"] = self._lr_scheduler["lr_scheduler"].state_dict()
+ state["scheduler_interval"] = self._lr_scheduler["interval"]
+
+ if self._swa_network is not None:
+ state["swa_network"] = self._swa_network.state_dict()
+
+ return state
+
+ def load_from_checkpoint(self, checkpoint_path: Union[str, Path]) -> None:
+ """Load a previously saved checkpoint.
+
+ Args:
+ checkpoint_path (Path): Path to the experiment with the checkpoint.
+
+ """
+ checkpoint_path = Path(checkpoint_path)
+ self.prepare_data()
+ self.configure_model()
+ logger.debug("Loading checkpoint...")
+ if not checkpoint_path.exists():
+ logger.debug("File does not exist {str(checkpoint_path)}")
+
+ checkpoint = torch.load(str(checkpoint_path), map_location=self.device)
+ self._network.load_state_dict(checkpoint["model_state"])
+
+ if self._optimizer is not None:
+ self._optimizer.load_state_dict(checkpoint["optimizer_state"])
+
+ if self._lr_scheduler is not None:
+ # Does not work when loading from previous checkpoint and trying to train beyond the last max epochs
+ # with OneCycleLR.
+ if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR":
+ self._lr_scheduler["lr_scheduler"].load_state_dict(
+ checkpoint["scheduler_state"]
+ )
+ self._lr_scheduler["interval"] = checkpoint["scheduler_interval"]
+
+ if self._swa_network is not None:
+ self._swa_network.load_state_dict(checkpoint["swa_network"])
+
+ def save_checkpoint(
+ self, checkpoint_path: Path, is_best: bool, epoch: int, val_metric: str
+ ) -> None:
+ """Saves a checkpoint of the model.
+
+ Args:
+ checkpoint_path (Path): Path to the experiment with the checkpoint.
+ is_best (bool): If it is the currently best model.
+ epoch (int): The epoch of the checkpoint.
+ val_metric (str): Validation metric.
+
+ """
+ state = self._get_state_dict()
+ state["is_best"] = is_best
+ state["epoch"] = epoch
+ state["network_args"] = self._network_args
+
+ checkpoint_path.mkdir(parents=True, exist_ok=True)
+
+ logger.debug("Saving checkpoint...")
+ filepath = str(checkpoint_path / "last.pt")
+ torch.save(state, filepath)
+
+ if is_best:
+ logger.debug(
+ f"Found a new best {val_metric}. Saving best checkpoint and weights."
+ )
+ shutil.copyfile(filepath, str(checkpoint_path / "best.pt"))
+
+ def load_weights(self, network_fn: Optional[Type[nn.Module]] = None) -> None:
+ """Load the network weights."""
+ logger.debug("Loading network with pretrained weights.")
+ filename = glob(self.weights_filename)[0]
+ if not filename:
+ raise FileNotFoundError(
+ f"Could not find any pretrained weights at {self.weights_filename}"
+ )
+ # Loading state directory.
+ state_dict = torch.load(filename, map_location=torch.device(self._device))
+ self._network_args = state_dict["network_args"]
+ weights = state_dict["model_state"]
+
+ # Initializes the network with trained weights.
+ if network_fn is not None:
+ self._network = network_fn(**self._network_args)
+ self._network.load_state_dict(weights)
+
+ if "swa_network" in state_dict:
+ self._swa_network = AveragedModel(self._network).to(self.device)
+ self._swa_network.load_state_dict(state_dict["swa_network"])
+
+ def save_weights(self, path: Path) -> None:
+ """Save the network weights."""
+ logger.debug("Saving the best network weights.")
+ shutil.copyfile(str(path / "best.pt"), self.weights_filename)
diff --git a/text_recognizer/models/character_model.py b/text_recognizer/models/character_model.py
new file mode 100644
index 0000000..f9944f3
--- /dev/null
+++ b/text_recognizer/models/character_model.py
@@ -0,0 +1,88 @@
+"""Defines the CharacterModel class."""
+from typing import Callable, Dict, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets import EmnistMapper
+from text_recognizer.models.base import Model
+
+
+class CharacterModel(Model):
+ """Model for predicting characters from images."""
+
+ def __init__(
+ self,
+ network_fn: Type[nn.Module],
+ dataset: Type[Dataset],
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ """Initializes the CharacterModel."""
+
+ super().__init__(
+ network_fn,
+ dataset,
+ network_args,
+ dataset_args,
+ metrics,
+ criterion,
+ criterion_args,
+ optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
+ swa_args,
+ device,
+ )
+ self.pad_token = dataset_args["args"]["pad_token"]
+ if self._mapper is None:
+ self._mapper = EmnistMapper(pad_token=self.pad_token,)
+ self.tensor_transform = ToTensor()
+ self.softmax = nn.Softmax(dim=0)
+
+ @torch.no_grad()
+ def predict_on_image(
+ self, image: Union[np.ndarray, torch.Tensor]
+ ) -> Tuple[str, float]:
+ """Character prediction on an image.
+
+ Args:
+ image (Union[np.ndarray, torch.Tensor]): An image containing a character.
+
+ Returns:
+ Tuple[str, float]: The predicted character and the confidence in the prediction.
+
+ """
+ self.eval()
+
+ if image.dtype == np.uint8:
+ # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
+ image = self.tensor_transform(image)
+ if image.dtype == torch.uint8:
+ # If the image is an unscaled tensor.
+ image = image.type("torch.FloatTensor") / 255
+
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+ logits = self.forward(image)
+
+ prediction = self.softmax(logits.squeeze(0))
+
+ index = int(torch.argmax(prediction, dim=0))
+ confidence_of_prediction = prediction[index]
+ predicted_character = self.mapper(index)
+
+ return predicted_character, confidence_of_prediction
diff --git a/text_recognizer/models/crnn_model.py b/text_recognizer/models/crnn_model.py
new file mode 100644
index 0000000..1e01a83
--- /dev/null
+++ b/text_recognizer/models/crnn_model.py
@@ -0,0 +1,119 @@
+"""Defines the CRNNModel class."""
+from typing import Callable, Dict, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets import EmnistMapper
+from text_recognizer.models.base import Model
+from text_recognizer.networks import greedy_decoder
+
+
+class CRNNModel(Model):
+ """Model for predicting a sequence of characters from an image of a text line."""
+
+ def __init__(
+ self,
+ network_fn: Type[nn.Module],
+ dataset: Type[Dataset],
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ super().__init__(
+ network_fn,
+ dataset,
+ network_args,
+ dataset_args,
+ metrics,
+ criterion,
+ criterion_args,
+ optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
+ swa_args,
+ device,
+ )
+
+ self.pad_token = dataset_args["args"]["pad_token"]
+ if self._mapper is None:
+ self._mapper = EmnistMapper(pad_token=self.pad_token,)
+ self.tensor_transform = ToTensor()
+
+ def criterion(self, output: Tensor, targets: Tensor) -> Tensor:
+ """Computes the CTC loss.
+
+ Args:
+ output (Tensor): Model predictions.
+ targets (Tensor): Correct output sequence.
+
+ Returns:
+ Tensor: The CTC loss.
+
+ """
+
+ # Input lengths on the form [T, B]
+ input_lengths = torch.full(
+ size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long,
+ )
+
+ # Configure target tensors for ctc loss.
+ targets_ = Tensor([]).to(self.device)
+ target_lengths = []
+ for t in targets:
+ # Remove padding symbol as it acts as the blank symbol.
+ t = t[t < 79]
+ targets_ = torch.cat([targets_, t])
+ target_lengths.append(len(t))
+
+ targets = targets_.type(dtype=torch.long)
+ target_lengths = (
+ torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device)
+ )
+
+ return self._criterion(output, targets, input_lengths, target_lengths)
+
+ @torch.no_grad()
+ def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
+ """Predict on a single input."""
+ self.eval()
+
+ if image.dtype == np.uint8:
+ # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
+ image = self.tensor_transform(image)
+
+ # Rescale image between 0 and 1.
+ if image.dtype == torch.uint8:
+ # If the image is an unscaled tensor.
+ image = image.type("torch.FloatTensor") / 255
+
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+ log_probs = self.forward(image)
+
+ raw_pred, _ = greedy_decoder(
+ predictions=log_probs,
+ character_mapper=self.mapper,
+ blank_label=79,
+ collapse_repeated=True,
+ )
+
+ log_probs, _ = log_probs.max(dim=2)
+
+ predicted_characters = "".join(raw_pred[0])
+ confidence_of_prediction = log_probs.cumprod(dim=0)[-1].item()
+
+ return predicted_characters, confidence_of_prediction
diff --git a/text_recognizer/models/ctc_transformer_model.py b/text_recognizer/models/ctc_transformer_model.py
new file mode 100644
index 0000000..25925f2
--- /dev/null
+++ b/text_recognizer/models/ctc_transformer_model.py
@@ -0,0 +1,120 @@
+"""Defines the CTC Transformer Model class."""
+from typing import Callable, Dict, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets import EmnistMapper
+from text_recognizer.models.base import Model
+from text_recognizer.networks import greedy_decoder
+
+
+class CTCTransformerModel(Model):
+ """Model for predicting a sequence of characters from an image of a text line."""
+
+ def __init__(
+ self,
+ network_fn: Type[nn.Module],
+ dataset: Type[Dataset],
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ super().__init__(
+ network_fn,
+ dataset,
+ network_args,
+ dataset_args,
+ metrics,
+ criterion,
+ criterion_args,
+ optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
+ swa_args,
+ device,
+ )
+ self.pad_token = dataset_args["args"]["pad_token"]
+ self.lower = dataset_args["args"]["lower"]
+
+ if self._mapper is None:
+ self._mapper = EmnistMapper(pad_token=self.pad_token, lower=self.lower,)
+
+ self.tensor_transform = ToTensor()
+
+ def criterion(self, output: Tensor, targets: Tensor) -> Tensor:
+ """Computes the CTC loss.
+
+ Args:
+ output (Tensor): Model predictions.
+ targets (Tensor): Correct output sequence.
+
+ Returns:
+ Tensor: The CTC loss.
+
+ """
+ # Input lengths on the form [T, B]
+ input_lengths = torch.full(
+ size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long,
+ )
+
+ # Configure target tensors for ctc loss.
+ targets_ = Tensor([]).to(self.device)
+ target_lengths = []
+ for t in targets:
+ # Remove padding symbol as it acts as the blank symbol.
+ t = t[t < 53]
+ targets_ = torch.cat([targets_, t])
+ target_lengths.append(len(t))
+
+ targets = targets_.type(dtype=torch.long)
+ target_lengths = (
+ torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device)
+ )
+
+ return self._criterion(output, targets, input_lengths, target_lengths)
+
+ @torch.no_grad()
+ def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
+ """Predict on a single input."""
+ self.eval()
+
+ if image.dtype == np.uint8:
+ # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
+ image = self.tensor_transform(image)
+
+ # Rescale image between 0 and 1.
+ if image.dtype == torch.uint8:
+ # If the image is an unscaled tensor.
+ image = image.type("torch.FloatTensor") / 255
+
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+ log_probs = self.forward(image)
+
+ raw_pred, _ = greedy_decoder(
+ predictions=log_probs,
+ character_mapper=self.mapper,
+ blank_label=53,
+ collapse_repeated=True,
+ )
+
+ log_probs, _ = log_probs.max(dim=2)
+
+ predicted_characters = "".join(raw_pred[0])
+ confidence_of_prediction = log_probs.cumprod(dim=0)[-1].item()
+
+ return predicted_characters, confidence_of_prediction
diff --git a/text_recognizer/models/segmentation_model.py b/text_recognizer/models/segmentation_model.py
new file mode 100644
index 0000000..613108a
--- /dev/null
+++ b/text_recognizer/models/segmentation_model.py
@@ -0,0 +1,75 @@
+"""Segmentation model for detecting and segmenting lines."""
+from typing import Callable, Dict, Optional, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.models.base import Model
+
+
+class SegmentationModel(Model):
+ """Model for segmenting lines in an image."""
+
+ def __init__(
+ self,
+ network_fn: str,
+ dataset: str,
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ super().__init__(
+ network_fn,
+ dataset,
+ network_args,
+ dataset_args,
+ metrics,
+ criterion,
+ criterion_args,
+ optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
+ swa_args,
+ device,
+ )
+ self.tensor_transform = ToTensor()
+ self.softmax = nn.Softmax(dim=2)
+
+ @torch.no_grad()
+ def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tensor:
+ """Predict on a single input."""
+ self.eval()
+
+ if image.dtype is np.uint8:
+ # Converts an image with range [0, 255] with to PyTorch Tensor with range [0, 1].
+ image = self.tensor_transform(image)
+
+ # Rescale image between 0 and 1.
+ if image.dtype is torch.uint8 or image.dtype is torch.int64:
+ # If the image is an unscaled tensor.
+ image = image.type("torch.FloatTensor") / 255
+
+ if not torch.is_tensor(image):
+ image = Tensor(image)
+
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+
+ logits = self.forward(image)
+
+ segmentation_mask = torch.argmax(logits, dim=1)
+
+ return segmentation_mask
diff --git a/text_recognizer/models/transformer_model.py b/text_recognizer/models/transformer_model.py
new file mode 100644
index 0000000..3f63053
--- /dev/null
+++ b/text_recognizer/models/transformer_model.py
@@ -0,0 +1,124 @@
+"""Defines the CNN-Transformer class."""
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+
+from text_recognizer.datasets import EmnistMapper
+import text_recognizer.datasets.transforms as transforms
+from text_recognizer.models.base import Model
+from text_recognizer.networks import greedy_decoder
+
+
+class TransformerModel(Model):
+ """Model for predicting a sequence of characters from an image of a text line with a cnn-transformer."""
+
+ def __init__(
+ self,
+ network_fn: str,
+ dataset: str,
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ super().__init__(
+ network_fn,
+ dataset,
+ network_args,
+ dataset_args,
+ metrics,
+ criterion,
+ criterion_args,
+ optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
+ swa_args,
+ device,
+ )
+ self.init_token = dataset_args["args"]["init_token"]
+ self.pad_token = dataset_args["args"]["pad_token"]
+ self.eos_token = dataset_args["args"]["eos_token"]
+ self.lower = dataset_args["args"]["lower"]
+ self.max_len = 100
+
+ if self._mapper is None:
+ self._mapper = EmnistMapper(
+ init_token=self.init_token,
+ pad_token=self.pad_token,
+ eos_token=self.eos_token,
+ lower=self.lower,
+ )
+ self.tensor_transform = transforms.Compose(
+ [transforms.ToTensor(), transforms.Normalize(mean=[0.912], std=[0.168])]
+ )
+ self.softmax = nn.Softmax(dim=2)
+
+ @torch.no_grad()
+ def _generate_sentence(self, image: Tensor) -> Tuple[List, float]:
+ src = self.network.extract_image_features(image)
+
+ # Added for vqvae transformer.
+ if isinstance(src, Tuple):
+ src = src[0]
+
+ memory = self.network.encoder(src)
+
+ confidence_of_predictions = []
+ trg_indices = [self.mapper(self.init_token)]
+
+ for _ in range(self.max_len - 1):
+ trg = torch.tensor(trg_indices, device=self.device)[None, :].long()
+ trg = self.network.target_embedding(trg)
+ logits = self.network.decoder(trg=trg, memory=memory, trg_mask=None)
+
+ # Convert logits to probabilities.
+ probs = self.softmax(logits)
+
+ pred_token = probs.argmax(2)[:, -1].item()
+ confidence = probs.max(2).values[:, -1].item()
+
+ trg_indices.append(pred_token)
+ confidence_of_predictions.append(confidence)
+
+ if pred_token == self.mapper(self.eos_token):
+ break
+
+ confidence = np.min(confidence_of_predictions)
+ predicted_characters = "".join([self.mapper(x) for x in trg_indices[1:]])
+
+ return predicted_characters, confidence
+
+ @torch.no_grad()
+ def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
+ """Predict on a single input."""
+ self.eval()
+
+ if image.dtype == np.uint8:
+ # Converts an image with range [0, 255] with to PyTorch Tensor with range [0, 1].
+ image = self.tensor_transform(image)
+
+ # Rescale image between 0 and 1.
+ if image.dtype == torch.uint8:
+ # If the image is an unscaled tensor.
+ image = image.type("torch.FloatTensor") / 255
+
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+
+ (predicted_characters, confidence_of_prediction,) = self._generate_sentence(
+ image
+ )
+
+ return predicted_characters, confidence_of_prediction
diff --git a/text_recognizer/models/vqvae_model.py b/text_recognizer/models/vqvae_model.py
new file mode 100644
index 0000000..70f6f1f
--- /dev/null
+++ b/text_recognizer/models/vqvae_model.py
@@ -0,0 +1,80 @@
+"""Defines the VQVAEModel class."""
+from typing import Callable, Dict, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets import EmnistMapper
+from text_recognizer.models.base import Model
+
+
+class VQVAEModel(Model):
+ """Model for reconstructing images from codebook."""
+
+ def __init__(
+ self,
+ network_fn: Type[nn.Module],
+ dataset: Type[Dataset],
+ network_args: Optional[Dict] = None,
+ dataset_args: Optional[Dict] = None,
+ metrics: Optional[Dict] = None,
+ criterion: Optional[Callable] = None,
+ criterion_args: Optional[Dict] = None,
+ optimizer: Optional[Callable] = None,
+ optimizer_args: Optional[Dict] = None,
+ lr_scheduler: Optional[Callable] = None,
+ lr_scheduler_args: Optional[Dict] = None,
+ swa_args: Optional[Dict] = None,
+ device: Optional[str] = None,
+ ) -> None:
+ """Initializes the CharacterModel."""
+
+ super().__init__(
+ network_fn,
+ dataset,
+ network_args,
+ dataset_args,
+ metrics,
+ criterion,
+ criterion_args,
+ optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
+ swa_args,
+ device,
+ )
+ self.pad_token = dataset_args["args"]["pad_token"]
+ if self._mapper is None:
+ self._mapper = EmnistMapper(pad_token=self.pad_token,)
+ self.tensor_transform = ToTensor()
+ self.softmax = nn.Softmax(dim=0)
+
+ @torch.no_grad()
+ def predict_on_image(self, image: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+ """Reconstruction of image.
+
+ Args:
+ image (Union[np.ndarray, torch.Tensor]): An image containing a character.
+
+ Returns:
+ Tuple[str, float]: The predicted character and the confidence in the prediction.
+
+ """
+ self.eval()
+
+ if image.dtype == np.uint8:
+ # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1].
+ image = self.tensor_transform(image)
+ if image.dtype == torch.uint8:
+ # If the image is an unscaled tensor.
+ image = image.type("torch.FloatTensor") / 255
+
+ # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device)
+ image_reconstructed, _ = self.forward(image)
+
+ return image_reconstructed
diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py
new file mode 100644
index 0000000..1521355
--- /dev/null
+++ b/text_recognizer/networks/__init__.py
@@ -0,0 +1,43 @@
+"""Network modules."""
+from .cnn import CNN
+from .cnn_transformer import CNNTransformer
+from .crnn import ConvolutionalRecurrentNetwork
+from .ctc import greedy_decoder
+from .densenet import DenseNet
+from .lenet import LeNet
+from .metrics import accuracy, cer, wer
+from .mlp import MLP
+from .residual_network import ResidualNetwork, ResidualNetworkEncoder
+from .transducer import load_transducer_loss, TDS2d
+from .transformer import Transformer
+from .unet import UNet
+from .util import sliding_window
+from .vit import ViT
+from .vq_transformer import VQTransformer
+from .vqvae import VQVAE
+from .wide_resnet import WideResidualNetwork
+
+__all__ = [
+ "accuracy",
+ "cer",
+ "CNN",
+ "CNNTransformer",
+ "ConvolutionalRecurrentNetwork",
+ "DenseNet",
+ "FCN",
+ "greedy_decoder",
+ "MLP",
+ "LeNet",
+ "load_transducer_loss",
+ "ResidualNetwork",
+ "ResidualNetworkEncoder",
+ "sliding_window",
+ "UNet",
+ "TDS2d",
+ "Transformer",
+ "ViT",
+ "VQTransformer",
+ "VQVAE",
+ "wer",
+ "WideResidualNetwork",
+]
diff --git a/text_recognizer/networks/beam.py b/text_recognizer/networks/beam.py
new file mode 100644
index 0000000..dccccdb
--- /dev/null
+++ b/text_recognizer/networks/beam.py
@@ -0,0 +1,83 @@
+"""Implementation of beam search decoder for a sequence to sequence network.
+
+Stolen from: https://github.com/budzianowski/PyTorch-Beam-Search-Decoding/blob/master/decode_beam.py
+
+"""
+# from typing import List
+# from Queue import PriorityQueue
+
+# from loguru import logger
+# import torch
+# from torch import nn
+# from torch import Tensor
+# import torch.nn.functional as F
+
+
+# class Node:
+# def __init__(
+# self, parent: Node, target_index: int, log_prob: Tensor, length: int
+# ) -> None:
+# self.parent = parent
+# self.target_index = target_index
+# self.log_prob = log_prob
+# self.length = length
+# self.reward = 0.0
+
+# def eval(self, alpha: float = 1.0) -> Tensor:
+# return self.log_prob / (self.length - 1 + 1e-6) + alpha * self.reward
+
+
+# @torch.no_grad()
+# def beam_decoder(
+# network, mapper, device, memory: Tensor = None, max_len: int = 97,
+# ) -> Tensor:
+# beam_width = 10
+# topk = 1 # How many sentences to generate.
+
+# trg_indices = [mapper(mapper.init_token)]
+
+# end_nodes = []
+
+# node = Node(None, trg_indices, 0, 1)
+# nodes = PriorityQueue()
+
+# nodes.put((node.eval(), node))
+# q_size = 1
+
+# # Beam search
+# for _ in range(max_len):
+# if q_size > 2000:
+# logger.warning("Could not decoder input")
+# break
+
+# # Fetch the best node.
+# score, n = nodes.get()
+# decoder_input = n.target_index
+
+# if n.target_index == mapper(mapper.eos_token) and n.parent is not None:
+# end_nodes.append((score, n))
+
+# # If we reached the maximum number of sentences required.
+# if len(end_nodes) >= 1:
+# break
+# else:
+# continue
+
+# # Forward pass with transformer.
+# trg = torch.tensor(trg_indices, device=device)[None, :].long()
+# trg = network.target_embedding(trg)
+# logits = network.decoder(trg=trg, memory=memory, trg_mask=None)
+# log_prob = F.log_softmax(logits, dim=2)
+
+# log_prob, indices = torch.topk(log_prob, beam_width)
+
+# for new_k in range(beam_width):
+# # TODO: continue from here
+# token_index = indices[0][new_k].view(1, -1)
+# log_p = log_prob[0][new_k].item()
+
+# node = Node()
+
+# pass
+
+# pass
diff --git a/text_recognizer/networks/cnn.py b/text_recognizer/networks/cnn.py
new file mode 100644
index 0000000..1807bb9
--- /dev/null
+++ b/text_recognizer/networks/cnn.py
@@ -0,0 +1,101 @@
+"""Implementation of a simple backbone cnn network."""
+from typing import Callable, Dict, Optional, Tuple
+
+from einops.layers.torch import Rearrange
+import torch
+from torch import nn
+
+from text_recognizer.networks.util import activation_function
+
+
+class CNN(nn.Module):
+ """LeNet network for character prediction."""
+
+ def __init__(
+ self,
+ channels: Tuple[int, ...] = (1, 32, 64, 128),
+ kernel_sizes: Tuple[int, ...] = (4, 4, 4),
+ strides: Tuple[int, ...] = (2, 2, 2),
+ max_pool_kernel: int = 2,
+ dropout_rate: float = 0.2,
+ activation: Optional[str] = "relu",
+ ) -> None:
+ """Initialization of the LeNet network.
+
+ Args:
+ channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64).
+ kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2).
+ strides (Tuple[int, ...]): Stride length of the convolutional filter. Defaults to (2, 2, 2).
+ max_pool_kernel (int): 2D max pooling kernel. Defaults to 2.
+ dropout_rate (float): The dropout rate. Defaults to 0.2.
+ activation (Optional[str]): The name of non-linear activation function. Defaults to relu.
+
+ Raises:
+ RuntimeError: if the number of hyperparameters does not match in length.
+
+ """
+ super().__init__()
+
+ if len(channels) - 1 != len(kernel_sizes) and len(kernel_sizes) != len(strides):
+ raise RuntimeError("The number of the hyperparameters does not match.")
+
+ self.cnn = self._build_network(
+ channels, kernel_sizes, strides, max_pool_kernel, dropout_rate, activation,
+ )
+
+ def _build_network(
+ self,
+ channels: Tuple[int, ...],
+ kernel_sizes: Tuple[int, ...],
+ strides: Tuple[int, ...],
+ max_pool_kernel: int,
+ dropout_rate: float,
+ activation: str,
+ ) -> nn.Sequential:
+ # Load activation function.
+ activation_fn = activation_function(activation)
+
+ channels = list(channels)
+ in_channels = channels.pop(0)
+ configuration = zip(channels, kernel_sizes, strides)
+
+ modules = nn.ModuleList([])
+
+ for i, (out_channels, kernel_size, stride) in enumerate(configuration):
+ # Add max pool to reduce output size.
+ if i == len(channels) // 2:
+ modules.append(nn.MaxPool2d(max_pool_kernel))
+ if i == 0:
+ modules.append(
+ nn.Conv2d(
+ in_channels, out_channels, kernel_size, stride=stride, padding=1
+ )
+ )
+ else:
+ modules.append(
+ nn.Sequential(
+ activation_fn,
+ nn.BatchNorm2d(in_channels),
+ nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=1,
+ ),
+ )
+ )
+
+ if dropout_rate:
+ modules.append(nn.Dropout2d(p=dropout_rate))
+
+ in_channels = out_channels
+
+ return nn.Sequential(*modules)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """The feedforward pass."""
+ # If batch dimenstion is missing, it needs to be added.
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
+ return self.cnn(x)
diff --git a/text_recognizer/networks/cnn_transformer.py b/text_recognizer/networks/cnn_transformer.py
new file mode 100644
index 0000000..9150b55
--- /dev/null
+++ b/text_recognizer/networks/cnn_transformer.py
@@ -0,0 +1,158 @@
+"""A CNN-Transformer for image to text recognition."""
+from typing import Dict, Optional, Tuple
+
+from einops import rearrange
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer import PositionalEncoding, Transformer
+from text_recognizer.networks.util import activation_function
+from text_recognizer.networks.util import configure_backbone
+
+
+class CNNTransformer(nn.Module):
+ """CNN+Transfomer for image to sequence prediction."""
+
+ def __init__(
+ self,
+ num_encoder_layers: int,
+ num_decoder_layers: int,
+ hidden_dim: int,
+ vocab_size: int,
+ num_heads: int,
+ adaptive_pool_dim: Tuple,
+ expansion_dim: int,
+ dropout_rate: float,
+ trg_pad_index: int,
+ max_len: int,
+ backbone: str,
+ backbone_args: Optional[Dict] = None,
+ activation: str = "gelu",
+ pool_kernel: Optional[Tuple[int, int]] = None,
+ ) -> None:
+ super().__init__()
+ self.trg_pad_index = trg_pad_index
+ self.vocab_size = vocab_size
+ self.backbone = configure_backbone(backbone, backbone_args)
+
+ if pool_kernel is not None:
+ self.max_pool = nn.MaxPool2d(pool_kernel, stride=2)
+ else:
+ self.max_pool = None
+
+ self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim)
+
+ self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
+ self.pos_dropout = nn.Dropout(p=dropout_rate)
+ self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+
+ nn.init.normal_(self.character_embedding.weight, std=0.02)
+
+ self.adaptive_pool = (
+ nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None
+ )
+
+ self.transformer = Transformer(
+ num_encoder_layers,
+ num_decoder_layers,
+ hidden_dim,
+ num_heads,
+ expansion_dim,
+ dropout_rate,
+ activation,
+ )
+
+ self.head = nn.Sequential(
+ # nn.Linear(hidden_dim, hidden_dim * 2),
+ # activation_function(activation),
+ nn.Linear(hidden_dim, vocab_size),
+ )
+
+ def _create_trg_mask(self, trg: Tensor) -> Tensor:
+ # Move this outside the transformer.
+ trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
+ trg_len = trg.shape[1]
+ trg_sub_mask = torch.tril(
+ torch.ones((trg_len, trg_len), device=trg.device)
+ ).bool()
+ trg_mask = trg_pad_mask & trg_sub_mask
+ return trg_mask
+
+ def encoder(self, src: Tensor) -> Tensor:
+ """Forward pass with the encoder of the transformer."""
+ return self.transformer.encoder(src)
+
+ def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
+ """Forward pass with the decoder of the transformer + classification head."""
+ return self.head(
+ self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+ )
+
+ def extract_image_features(self, src: Tensor) -> Tensor:
+ """Extracts image features with a backbone neural network.
+
+ It seem like the winning idea was to swap channels and width dimension and collapse
+ the height dimension. The transformer is learning like a baby with this implementation!!! :D
+ Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D
+
+ Args:
+ src (Tensor): Input tensor.
+
+ Returns:
+ Tensor: A input src to the transformer.
+
+ """
+ # If batch dimension is missing, it needs to be added.
+ if len(src.shape) < 4:
+ src = src[(None,) * (4 - len(src.shape))]
+
+ src = self.backbone(src)
+
+ if self.max_pool is not None:
+ src = self.max_pool(src)
+
+ if self.adaptive_pool is not None and len(src.shape) == 4:
+ src = rearrange(src, "b c h w -> b w c h")
+ src = self.adaptive_pool(src)
+ src = src.squeeze(3)
+ elif len(src.shape) == 4:
+ src = rearrange(src, "b c h w -> b (h w) c")
+
+ b, t, _ = src.shape
+
+ src += self.src_position_embedding[:, :t]
+ src = self.pos_dropout(src)
+
+ return src
+
+ def target_embedding(self, trg: Tensor) -> Tuple[Tensor, Tensor]:
+ """Encodes target tensor with embedding and postion.
+
+ Args:
+ trg (Tensor): Target tensor.
+
+ Returns:
+ Tuple[Tensor, Tensor]: Encoded target tensor and target mask.
+
+ """
+ trg = self.character_embedding(trg.long())
+ trg = self.trg_position_encoding(trg)
+ return trg
+
+ def decode_image_features(
+ self, image_features: Tensor, trg: Optional[Tensor] = None
+ ) -> Tensor:
+ """Takes images features from the backbone and decodes them with the transformer."""
+ trg_mask = self._create_trg_mask(trg)
+ trg = self.target_embedding(trg)
+ out = self.transformer(image_features, trg, trg_mask=trg_mask)
+
+ logits = self.head(out)
+ return logits
+
+ def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+ """Forward pass with CNN transfomer."""
+ image_features = self.extract_image_features(x)
+ logits = self.decode_image_features(image_features, trg)
+ return logits
diff --git a/text_recognizer/networks/crnn.py b/text_recognizer/networks/crnn.py
new file mode 100644
index 0000000..778e232
--- /dev/null
+++ b/text_recognizer/networks/crnn.py
@@ -0,0 +1,110 @@
+"""CRNN for handwritten text recognition."""
+from typing import Dict, Tuple
+
+from einops import rearrange, reduce
+from einops.layers.torch import Rearrange
+from loguru import logger
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import configure_backbone
+
+
+class ConvolutionalRecurrentNetwork(nn.Module):
+ """Network that takes a image of a text line and predicts tokens that are in the image."""
+
+ def __init__(
+ self,
+ backbone: str,
+ backbone_args: Dict = None,
+ input_size: int = 128,
+ hidden_size: int = 128,
+ bidirectional: bool = False,
+ num_layers: int = 1,
+ num_classes: int = 80,
+ patch_size: Tuple[int, int] = (28, 28),
+ stride: Tuple[int, int] = (1, 14),
+ recurrent_cell: str = "lstm",
+ avg_pool: bool = False,
+ use_sliding_window: bool = True,
+ ) -> None:
+ super().__init__()
+ self.backbone_args = backbone_args or {}
+ self.patch_size = patch_size
+ self.stride = stride
+ self.sliding_window = (
+ self._configure_sliding_window() if use_sliding_window else None
+ )
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.backbone = configure_backbone(backbone, backbone_args)
+ self.bidirectional = bidirectional
+ self.avg_pool = avg_pool
+
+ if recurrent_cell.upper() in ["LSTM", "GRU"]:
+ recurrent_cell = getattr(nn, recurrent_cell)
+ else:
+ logger.warning(
+ f"Option {recurrent_cell} not valid, defaulting to LSTM cell."
+ )
+ recurrent_cell = nn.LSTM
+
+ self.rnn = recurrent_cell(
+ input_size=self.input_size,
+ hidden_size=self.hidden_size,
+ bidirectional=bidirectional,
+ num_layers=num_layers,
+ )
+
+ decoder_size = self.hidden_size * 2 if self.bidirectional else self.hidden_size
+
+ self.decoder = nn.Sequential(
+ nn.Linear(in_features=decoder_size, out_features=num_classes),
+ nn.LogSoftmax(dim=2),
+ )
+
+ def _configure_sliding_window(self) -> nn.Sequential:
+ return nn.Sequential(
+ nn.Unfold(kernel_size=self.patch_size, stride=self.stride),
+ Rearrange(
+ "b (c h w) t -> b t c h w",
+ h=self.patch_size[0],
+ w=self.patch_size[1],
+ c=1,
+ ),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM."""
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
+
+ if self.sliding_window is not None:
+ # Create image patches with a sliding window kernel.
+ x = self.sliding_window(x)
+
+ # Rearrange from a sequence of patches for feedforward network.
+ b, t = x.shape[:2]
+ x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
+
+ x = self.backbone(x)
+
+ # Average pooling.
+ if self.avg_pool:
+ x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)
+ else:
+ x = rearrange(x, "(b t) h -> t b h", b=b, t=t)
+ else:
+ # Encode the entire image with a CNN, and use the channels as temporal dimension.
+ x = self.backbone(x)
+ x = rearrange(x, "b c h w -> b w c h")
+ if self.adaptive_pool is not None:
+ x = self.adaptive_pool(x)
+ x = x.squeeze(3)
+
+ # Sequence predictions.
+ x, _ = self.rnn(x)
+
+ # Sequence to classification layer.
+ x = self.decoder(x)
+ return x
diff --git a/text_recognizer/networks/ctc.py b/text_recognizer/networks/ctc.py
new file mode 100644
index 0000000..af9b700
--- /dev/null
+++ b/text_recognizer/networks/ctc.py
@@ -0,0 +1,58 @@
+"""Decodes the CTC output."""
+from typing import Callable, List, Optional, Tuple
+
+from einops import rearrange
+import torch
+from torch import Tensor
+
+from text_recognizer.datasets.util import EmnistMapper
+
+
+def greedy_decoder(
+ predictions: Tensor,
+ targets: Optional[Tensor] = None,
+ target_lengths: Optional[Tensor] = None,
+ character_mapper: Optional[Callable] = None,
+ blank_label: int = 79,
+ collapse_repeated: bool = True,
+) -> Tuple[List[str], List[str]]:
+ """Greedy CTC decoder.
+
+ Args:
+ predictions (Tensor): Tenor of network predictions, shape [time, batch, classes].
+ targets (Optional[Tensor]): Target tensor, shape is [batch, targets]. Defaults to None.
+ target_lengths (Optional[Tensor]): Length of each target tensor. Defaults to None.
+ character_mapper (Optional[Callable]): A emnist/character mapper for mapping integers to characters. Defaults
+ to None.
+ blank_label (int): The blank character to be ignored. Defaults to 80.
+ collapse_repeated (bool): Collapase consecutive predictions of the same character. Defaults to True.
+
+ Returns:
+ Tuple[List[str], List[str]]: Tuple of decoded predictions and decoded targets.
+
+ """
+
+ if character_mapper is None:
+ character_mapper = EmnistMapper(pad_token="_") # noqa: S106
+
+ predictions = rearrange(torch.argmax(predictions, dim=2), "t b -> b t")
+ decoded_predictions = []
+ decoded_targets = []
+ for i, prediction in enumerate(predictions):
+ decoded_prediction = []
+ decoded_target = []
+ if targets is not None and target_lengths is not None:
+ for target_index in targets[i][: target_lengths[i]]:
+ if target_index == blank_label:
+ continue
+ decoded_target.append(character_mapper(int(target_index)))
+ decoded_targets.append(decoded_target)
+ for j, index in enumerate(prediction):
+ if index != blank_label:
+ if collapse_repeated and j != 0 and index == prediction[j - 1]:
+ continue
+ decoded_prediction.append(index.item())
+ decoded_predictions.append(
+ [character_mapper(int(pred_index)) for pred_index in decoded_prediction]
+ )
+ return decoded_predictions, decoded_targets
diff --git a/text_recognizer/networks/densenet.py b/text_recognizer/networks/densenet.py
new file mode 100644
index 0000000..7dc58d9
--- /dev/null
+++ b/text_recognizer/networks/densenet.py
@@ -0,0 +1,225 @@
+"""Defines a Densely Connected Convolutional Networks in PyTorch.
+
+Sources:
+https://arxiv.org/abs/1608.06993
+https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py
+
+"""
+from typing import List, Optional, Union
+
+from einops.layers.torch import Rearrange
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import activation_function
+
+
+class _DenseLayer(nn.Module):
+ """A dense layer with pre-batch norm -> activation function -> Conv-layer x 2."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ growth_rate: int,
+ bn_size: int,
+ dropout_rate: float,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ activation_fn = activation_function(activation)
+ self.dense_layer = [
+ nn.BatchNorm2d(in_channels),
+ activation_fn,
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=bn_size * growth_rate,
+ kernel_size=1,
+ stride=1,
+ bias=False,
+ ),
+ nn.BatchNorm2d(bn_size * growth_rate),
+ activation_fn,
+ nn.Conv2d(
+ in_channels=bn_size * growth_rate,
+ out_channels=growth_rate,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ ),
+ ]
+ if dropout_rate:
+ self.dense_layer.append(nn.Dropout(p=dropout_rate))
+
+ self.dense_layer = nn.Sequential(*self.dense_layer)
+
+ def forward(self, x: Union[Tensor, List[Tensor]]) -> Tensor:
+ if isinstance(x, list):
+ x = torch.cat(x, 1)
+ return self.dense_layer(x)
+
+
+class _DenseBlock(nn.Module):
+ def __init__(
+ self,
+ num_layers: int,
+ in_channels: int,
+ bn_size: int,
+ growth_rate: int,
+ dropout_rate: float,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ self.dense_block = self._build_dense_blocks(
+ num_layers, in_channels, bn_size, growth_rate, dropout_rate, activation,
+ )
+
+ def _build_dense_blocks(
+ self,
+ num_layers: int,
+ in_channels: int,
+ bn_size: int,
+ growth_rate: int,
+ dropout_rate: float,
+ activation: str = "relu",
+ ) -> nn.ModuleList:
+ dense_block = []
+ for i in range(num_layers):
+ dense_block.append(
+ _DenseLayer(
+ in_channels=in_channels + i * growth_rate,
+ growth_rate=growth_rate,
+ bn_size=bn_size,
+ dropout_rate=dropout_rate,
+ activation=activation,
+ )
+ )
+ return nn.ModuleList(dense_block)
+
+ def forward(self, x: Tensor) -> Tensor:
+ feature_maps = [x]
+ for layer in self.dense_block:
+ x = layer(feature_maps)
+ feature_maps.append(x)
+ return torch.cat(feature_maps, 1)
+
+
+class _Transition(nn.Module):
+ def __init__(
+ self, in_channels: int, out_channels: int, activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ activation_fn = activation_function(activation)
+ self.transition = nn.Sequential(
+ nn.BatchNorm2d(in_channels),
+ activation_fn,
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ bias=False,
+ ),
+ nn.AvgPool2d(kernel_size=2, stride=2),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.transition(x)
+
+
+class DenseNet(nn.Module):
+ """Implementation of Densenet, a network archtecture that concats previous layers for maximum infomation flow."""
+
+ def __init__(
+ self,
+ growth_rate: int = 32,
+ block_config: List[int] = (6, 12, 24, 16),
+ in_channels: int = 1,
+ base_channels: int = 64,
+ num_classes: int = 80,
+ bn_size: int = 4,
+ dropout_rate: float = 0,
+ classifier: bool = True,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ self.densenet = self._configure_densenet(
+ in_channels,
+ base_channels,
+ num_classes,
+ growth_rate,
+ block_config,
+ bn_size,
+ dropout_rate,
+ classifier,
+ activation,
+ )
+
+ def _configure_densenet(
+ self,
+ in_channels: int,
+ base_channels: int,
+ num_classes: int,
+ growth_rate: int,
+ block_config: List[int],
+ bn_size: int,
+ dropout_rate: float,
+ classifier: bool,
+ activation: str,
+ ) -> nn.Sequential:
+ activation_fn = activation_function(activation)
+ densenet = [
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=base_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ ),
+ nn.BatchNorm2d(base_channels),
+ activation_fn,
+ ]
+
+ num_features = base_channels
+
+ for i, num_layers in enumerate(block_config):
+ densenet.append(
+ _DenseBlock(
+ num_layers=num_layers,
+ in_channels=num_features,
+ bn_size=bn_size,
+ growth_rate=growth_rate,
+ dropout_rate=dropout_rate,
+ activation=activation,
+ )
+ )
+ num_features = num_features + num_layers * growth_rate
+ if i != len(block_config) - 1:
+ densenet.append(
+ _Transition(
+ in_channels=num_features,
+ out_channels=num_features // 2,
+ activation=activation,
+ )
+ )
+ num_features = num_features // 2
+
+ densenet.append(activation_fn)
+
+ if classifier:
+ densenet.append(nn.AdaptiveAvgPool2d((1, 1)))
+ densenet.append(Rearrange("b c h w -> b (c h w)"))
+ densenet.append(
+ nn.Linear(in_features=num_features, out_features=num_classes)
+ )
+
+ return nn.Sequential(*densenet)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass of Densenet."""
+ # If batch dimenstion is missing, it will be added.
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
+ return self.densenet(x)
diff --git a/text_recognizer/networks/lenet.py b/text_recognizer/networks/lenet.py
new file mode 100644
index 0000000..527e1a0
--- /dev/null
+++ b/text_recognizer/networks/lenet.py
@@ -0,0 +1,68 @@
+"""Implementation of the LeNet network."""
+from typing import Callable, Dict, Optional, Tuple
+
+from einops.layers.torch import Rearrange
+import torch
+from torch import nn
+
+from text_recognizer.networks.util import activation_function
+
+
+class LeNet(nn.Module):
+ """LeNet network for character prediction."""
+
+ def __init__(
+ self,
+ channels: Tuple[int, ...] = (1, 32, 64),
+ kernel_sizes: Tuple[int, ...] = (3, 3, 2),
+ hidden_size: Tuple[int, ...] = (9216, 128),
+ dropout_rate: float = 0.2,
+ num_classes: int = 10,
+ activation_fn: Optional[str] = "relu",
+ ) -> None:
+ """Initialization of the LeNet network.
+
+ Args:
+ channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64).
+ kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2).
+ hidden_size (Tuple[int, ...]): Size of the flattend output form the convolutional layers.
+ Defaults to (9216, 128).
+ dropout_rate (float): The dropout rate. Defaults to 0.2.
+ num_classes (int): Number of classes. Defaults to 10.
+ activation_fn (Optional[str]): The name of non-linear activation function. Defaults to relu.
+
+ """
+ super().__init__()
+
+ activation_fn = activation_function(activation_fn)
+
+ self.layers = [
+ nn.Conv2d(
+ in_channels=channels[0],
+ out_channels=channels[1],
+ kernel_size=kernel_sizes[0],
+ ),
+ activation_fn,
+ nn.Conv2d(
+ in_channels=channels[1],
+ out_channels=channels[2],
+ kernel_size=kernel_sizes[1],
+ ),
+ activation_fn,
+ nn.MaxPool2d(kernel_sizes[2]),
+ nn.Dropout(p=dropout_rate),
+ Rearrange("b c h w -> b (c h w)"),
+ nn.Linear(in_features=hidden_size[0], out_features=hidden_size[1]),
+ activation_fn,
+ nn.Dropout(p=dropout_rate),
+ nn.Linear(in_features=hidden_size[1], out_features=num_classes),
+ ]
+
+ self.layers = nn.Sequential(*self.layers)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """The feedforward pass."""
+ # If batch dimenstion is missing, it needs to be added.
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
+ return self.layers(x)
diff --git a/text_recognizer/networks/loss/__init__.py b/text_recognizer/networks/loss/__init__.py
new file mode 100644
index 0000000..b489264
--- /dev/null
+++ b/text_recognizer/networks/loss/__init__.py
@@ -0,0 +1,2 @@
+"""Loss module."""
+from .loss import EmbeddingLoss, LabelSmoothingCrossEntropy
diff --git a/text_recognizer/networks/loss/loss.py b/text_recognizer/networks/loss/loss.py
new file mode 100644
index 0000000..cf9fa0d
--- /dev/null
+++ b/text_recognizer/networks/loss/loss.py
@@ -0,0 +1,69 @@
+"""Implementations of custom loss functions."""
+from pytorch_metric_learning import distances, losses, miners, reducers
+import torch
+from torch import nn
+from torch import Tensor
+from torch.autograd import Variable
+import torch.nn.functional as F
+
+__all__ = ["EmbeddingLoss", "LabelSmoothingCrossEntropy"]
+
+
+class EmbeddingLoss:
+ """Metric loss for training encoders to produce information-rich latent embeddings."""
+
+ def __init__(self, margin: float = 0.2, type_of_triplets: str = "semihard") -> None:
+ self.distance = distances.CosineSimilarity()
+ self.reducer = reducers.ThresholdReducer(low=0)
+ self.loss_fn = losses.TripletMarginLoss(
+ margin=margin, distance=self.distance, reducer=self.reducer
+ )
+ self.miner = miners.MultiSimilarityMiner(epsilon=margin, distance=self.distance)
+
+ def __call__(self, embeddings: Tensor, labels: Tensor) -> Tensor:
+ """Computes the metric loss for the embeddings based on their labels.
+
+ Args:
+ embeddings (Tensor): The laten vectors encoded by the network.
+ labels (Tensor): Labels of the embeddings.
+
+ Returns:
+ Tensor: The metric loss for the embeddings.
+
+ """
+ hard_pairs = self.miner(embeddings, labels)
+ loss = self.loss_fn(embeddings, labels, hard_pairs)
+ return loss
+
+
+class LabelSmoothingCrossEntropy(nn.Module):
+ """Label smoothing loss function."""
+
+ def __init__(
+ self,
+ classes: int,
+ smoothing: float = 0.0,
+ ignore_index: int = None,
+ dim: int = -1,
+ ) -> None:
+ super().__init__()
+ self.confidence = 1.0 - smoothing
+ self.smoothing = smoothing
+ self.ignore_index = ignore_index
+ self.cls = classes
+ self.dim = dim
+
+ def forward(self, pred: Tensor, target: Tensor) -> Tensor:
+ """Calculates the loss."""
+ pred = pred.log_softmax(dim=self.dim)
+ with torch.no_grad():
+ # true_dist = pred.data.clone()
+ true_dist = torch.zeros_like(pred)
+ true_dist.fill_(self.smoothing / (self.cls - 1))
+ true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
+ if self.ignore_index is not None:
+ true_dist[:, self.ignore_index] = 0
+ mask = torch.nonzero(target == self.ignore_index, as_tuple=False)
+ if mask.dim() > 0:
+ true_dist.index_fill_(0, mask.squeeze(), 0.0)
+ return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
diff --git a/text_recognizer/networks/metrics.py b/text_recognizer/networks/metrics.py
new file mode 100644
index 0000000..2605731
--- /dev/null
+++ b/text_recognizer/networks/metrics.py
@@ -0,0 +1,123 @@
+"""Utility functions for models."""
+from typing import Optional
+
+from einops import rearrange
+import Levenshtein as Lev
+import torch
+from torch import Tensor
+
+from text_recognizer.networks import greedy_decoder
+
+
+def accuracy(outputs: Tensor, labels: Tensor, pad_index: int = 53) -> float:
+ """Computes the accuracy.
+
+ Args:
+ outputs (Tensor): The output from the network.
+ labels (Tensor): Ground truth labels.
+ pad_index (int): Padding index.
+
+ Returns:
+ float: The accuracy for the batch.
+
+ """
+
+ _, predicted = torch.max(outputs, dim=-1)
+
+ # Mask out the pad tokens
+ mask = labels != pad_index
+
+ predicted *= mask
+ labels *= mask
+
+ acc = (predicted == labels).sum().float() / labels.shape[0]
+ acc = acc.item()
+ return acc
+
+
+def cer(
+ outputs: Tensor,
+ targets: Tensor,
+ batch_size: Optional[int] = None,
+ blank_label: Optional[int] = int,
+) -> float:
+ """Computes the character error rate.
+
+ Args:
+ outputs (Tensor): The output from the network.
+ targets (Tensor): Ground truth labels.
+ batch_size (Optional[int]): Batch size if target and output has been flattend.
+ blank_label (Optional[int]): The blank character to be ignored. Defaults to 79.
+
+ Returns:
+ float: The cer for the batch.
+
+ """
+ if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None:
+ targets = rearrange(targets, "(b t) -> b t", b=batch_size)
+ outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size)
+
+ target_lengths = torch.full(
+ size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
+ )
+ decoded_predictions, decoded_targets = greedy_decoder(
+ outputs, targets, target_lengths, blank_label=blank_label,
+ )
+
+ lev_dist = 0
+
+ for prediction, target in zip(decoded_predictions, decoded_targets):
+ prediction = "".join(prediction)
+ target = "".join(target)
+ prediction, target = (
+ prediction.replace(" ", ""),
+ target.replace(" ", ""),
+ )
+ lev_dist += Lev.distance(prediction, target)
+ return lev_dist / len(decoded_predictions)
+
+
+def wer(
+ outputs: Tensor,
+ targets: Tensor,
+ batch_size: Optional[int] = None,
+ blank_label: Optional[int] = int,
+) -> float:
+ """Computes the Word error rate.
+
+ Args:
+ outputs (Tensor): The output from the network.
+ targets (Tensor): Ground truth labels.
+ batch_size (optional[int]): Batch size if target and output has been flattend.
+ blank_label (Optional[int]): The blank character to be ignored. Defaults to 79.
+
+ Returns:
+ float: The wer for the batch.
+
+ """
+ if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None:
+ targets = rearrange(targets, "(b t) -> b t", b=batch_size)
+ outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size)
+
+ target_lengths = torch.full(
+ size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
+ )
+ decoded_predictions, decoded_targets = greedy_decoder(
+ outputs, targets, target_lengths, blank_label=blank_label,
+ )
+
+ lev_dist = 0
+
+ for prediction, target in zip(decoded_predictions, decoded_targets):
+ prediction = "".join(prediction)
+ target = "".join(target)
+
+ b = set(prediction.split() + target.split())
+ word2char = dict(zip(b, range(len(b))))
+
+ w1 = [chr(word2char[w]) for w in prediction.split()]
+ w2 = [chr(word2char[w]) for w in target.split()]
+
+ lev_dist += Lev.distance("".join(w1), "".join(w2))
+
+ return lev_dist / len(decoded_predictions)
diff --git a/text_recognizer/networks/mlp.py b/text_recognizer/networks/mlp.py
new file mode 100644
index 0000000..1101912
--- /dev/null
+++ b/text_recognizer/networks/mlp.py
@@ -0,0 +1,73 @@
+"""Defines the MLP network."""
+from typing import Callable, Dict, List, Optional, Union
+
+from einops.layers.torch import Rearrange
+import torch
+from torch import nn
+
+from text_recognizer.networks.util import activation_function
+
+
+class MLP(nn.Module):
+ """Multi layered perceptron network."""
+
+ def __init__(
+ self,
+ input_size: int = 784,
+ num_classes: int = 10,
+ hidden_size: Union[int, List] = 128,
+ num_layers: int = 3,
+ dropout_rate: float = 0.2,
+ activation_fn: str = "relu",
+ ) -> None:
+ """Initialization of the MLP network.
+
+ Args:
+ input_size (int): The input shape of the network. Defaults to 784.
+ num_classes (int): Number of classes in the dataset. Defaults to 10.
+ hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128.
+ num_layers (int): The number of hidden layers. Defaults to 3.
+ dropout_rate (float): The dropout rate at each layer. Defaults to 0.2.
+ activation_fn (str): Name of the activation function in the hidden layers. Defaults to
+ relu.
+
+ """
+ super().__init__()
+
+ activation_fn = activation_function(activation_fn)
+
+ if isinstance(hidden_size, int):
+ hidden_size = [hidden_size] * num_layers
+
+ self.layers = [
+ Rearrange("b c h w -> b (c h w)"),
+ nn.Linear(in_features=input_size, out_features=hidden_size[0]),
+ activation_fn,
+ ]
+
+ for i in range(num_layers - 1):
+ self.layers += [
+ nn.Linear(in_features=hidden_size[i], out_features=hidden_size[i + 1]),
+ activation_fn,
+ ]
+
+ if dropout_rate:
+ self.layers.append(nn.Dropout(p=dropout_rate))
+
+ self.layers.append(
+ nn.Linear(in_features=hidden_size[-1], out_features=num_classes)
+ )
+
+ self.layers = nn.Sequential(*self.layers)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """The feedforward pass."""
+ # If batch dimenstion is missing, it needs to be added.
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
+ return self.layers(x)
+
+ @property
+ def __name__(self) -> str:
+ """Returns the name of the network."""
+ return "mlp"
diff --git a/text_recognizer/networks/residual_network.py b/text_recognizer/networks/residual_network.py
new file mode 100644
index 0000000..c33f419
--- /dev/null
+++ b/text_recognizer/networks/residual_network.py
@@ -0,0 +1,310 @@
+"""Residual CNN."""
+from functools import partial
+from typing import Callable, Dict, List, Optional, Type, Union
+
+from einops.layers.torch import Rearrange, Reduce
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import activation_function
+
+
+class Conv2dAuto(nn.Conv2d):
+ """Convolution with auto padding based on kernel size."""
+
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2)
+
+
+def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential:
+ """3x3 convolution with batch norm."""
+ conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,)
+ return nn.Sequential(
+ conv3x3(in_channels, out_channels, *args, **kwargs),
+ nn.BatchNorm2d(out_channels),
+ )
+
+
+class IdentityBlock(nn.Module):
+ """Residual with identity block."""
+
+ def __init__(
+ self, in_channels: int, out_channels: int, activation: str = "relu"
+ ) -> None:
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.blocks = nn.Identity()
+ self.activation_fn = activation_function(activation)
+ self.shortcut = nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ residual = x
+ if self.apply_shortcut:
+ residual = self.shortcut(x)
+ x = self.blocks(x)
+ x += residual
+ x = self.activation_fn(x)
+ return x
+
+ @property
+ def apply_shortcut(self) -> bool:
+ """Check if shortcut should be applied."""
+ return self.in_channels != self.out_channels
+
+
+class ResidualBlock(IdentityBlock):
+ """Residual with nonlinear shortcut."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ expansion: int = 1,
+ downsampling: int = 1,
+ *args,
+ **kwargs
+ ) -> None:
+ """Short summary.
+
+ Args:
+ in_channels (int): Number of in channels.
+ out_channels (int): umber of out channels.
+ expansion (int): Expansion factor of the out channels. Defaults to 1.
+ downsampling (int): Downsampling factor used in stride. Defaults to 1.
+ *args (type): Extra arguments.
+ **kwargs (type): Extra key value arguments.
+
+ """
+ super().__init__(in_channels, out_channels, *args, **kwargs)
+ self.expansion = expansion
+ self.downsampling = downsampling
+
+ self.shortcut = (
+ nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.in_channels,
+ out_channels=self.expanded_channels,
+ kernel_size=1,
+ stride=self.downsampling,
+ bias=False,
+ ),
+ nn.BatchNorm2d(self.expanded_channels),
+ )
+ if self.apply_shortcut
+ else None
+ )
+
+ @property
+ def expanded_channels(self) -> int:
+ """Computes the expanded output channels."""
+ return self.out_channels * self.expansion
+
+ @property
+ def apply_shortcut(self) -> bool:
+ """Check if shortcut should be applied."""
+ return self.in_channels != self.expanded_channels
+
+
+class BasicBlock(ResidualBlock):
+ """Basic ResNet block."""
+
+ expansion = 1
+
+ def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None:
+ super().__init__(in_channels, out_channels, *args, **kwargs)
+ self.blocks = nn.Sequential(
+ conv_bn(
+ in_channels=self.in_channels,
+ out_channels=self.out_channels,
+ bias=False,
+ stride=self.downsampling,
+ ),
+ self.activation_fn,
+ conv_bn(
+ in_channels=self.out_channels,
+ out_channels=self.expanded_channels,
+ bias=False,
+ ),
+ )
+
+
+class BottleNeckBlock(ResidualBlock):
+ """Bottleneck block to increase depth while minimizing parameter size."""
+
+ expansion = 4
+
+ def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None:
+ super().__init__(in_channels, out_channels, *args, **kwargs)
+ self.blocks = nn.Sequential(
+ conv_bn(
+ in_channels=self.in_channels,
+ out_channels=self.out_channels,
+ kernel_size=1,
+ ),
+ self.activation_fn,
+ conv_bn(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels,
+ kernel_size=3,
+ stride=self.downsampling,
+ ),
+ self.activation_fn,
+ conv_bn(
+ in_channels=self.out_channels,
+ out_channels=self.expanded_channels,
+ kernel_size=1,
+ ),
+ )
+
+
+class ResidualLayer(nn.Module):
+ """ResNet layer."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ block: BasicBlock = BasicBlock,
+ num_blocks: int = 1,
+ *args,
+ **kwargs
+ ) -> None:
+ super().__init__()
+ downsampling = 2 if in_channels != out_channels else 1
+ self.blocks = nn.Sequential(
+ block(
+ in_channels, out_channels, *args, **kwargs, downsampling=downsampling
+ ),
+ *[
+ block(
+ out_channels * block.expansion,
+ out_channels,
+ downsampling=1,
+ *args,
+ **kwargs
+ )
+ for _ in range(num_blocks - 1)
+ ]
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ x = self.blocks(x)
+ return x
+
+
+class ResidualNetworkEncoder(nn.Module):
+ """Encoder network."""
+
+ def __init__(
+ self,
+ in_channels: int = 1,
+ block_sizes: Union[int, List[int]] = (32, 64),
+ depths: Union[int, List[int]] = (2, 2),
+ activation: str = "relu",
+ block: Type[nn.Module] = BasicBlock,
+ levels: int = 1,
+ *args,
+ **kwargs
+ ) -> None:
+ super().__init__()
+ self.block_sizes = (
+ block_sizes if isinstance(block_sizes, list) else [block_sizes] * levels
+ )
+ self.depths = depths if isinstance(depths, list) else [depths] * levels
+ self.activation = activation
+ self.gate = nn.Sequential(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=self.block_sizes[0],
+ kernel_size=7,
+ stride=2,
+ padding=1,
+ bias=False,
+ ),
+ nn.BatchNorm2d(self.block_sizes[0]),
+ activation_function(self.activation),
+ # nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
+ )
+
+ self.blocks = self._configure_blocks(block)
+
+ def _configure_blocks(
+ self, block: Type[nn.Module], *args, **kwargs
+ ) -> nn.Sequential:
+ channels = [self.block_sizes[0]] + list(
+ zip(self.block_sizes, self.block_sizes[1:])
+ )
+ blocks = [
+ ResidualLayer(
+ in_channels=channels[0],
+ out_channels=channels[0],
+ num_blocks=self.depths[0],
+ block=block,
+ activation=self.activation,
+ *args,
+ **kwargs
+ )
+ ]
+ blocks += [
+ ResidualLayer(
+ in_channels=in_channels * block.expansion,
+ out_channels=out_channels,
+ num_blocks=num_blocks,
+ block=block,
+ activation=self.activation,
+ *args,
+ **kwargs
+ )
+ for (in_channels, out_channels), num_blocks in zip(
+ channels[1:], self.depths[1:]
+ )
+ ]
+
+ return nn.Sequential(*blocks)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ # If batch dimenstion is missing, it needs to be added.
+ if len(x.shape) == 3:
+ x = x.unsqueeze(0)
+ x = self.gate(x)
+ x = self.blocks(x)
+ return x
+
+
+class ResidualNetworkDecoder(nn.Module):
+ """Classification head."""
+
+ def __init__(self, in_features: int, num_classes: int = 80) -> None:
+ super().__init__()
+ self.decoder = nn.Sequential(
+ Reduce("b c h w -> b c", "mean"),
+ nn.Linear(in_features=in_features, out_features=num_classes),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ return self.decoder(x)
+
+
+class ResidualNetwork(nn.Module):
+ """Full residual network."""
+
+ def __init__(self, in_channels: int, num_classes: int, *args, **kwargs) -> None:
+ super().__init__()
+ self.encoder = ResidualNetworkEncoder(in_channels, *args, **kwargs)
+ self.decoder = ResidualNetworkDecoder(
+ in_features=self.encoder.blocks[-1].blocks[-1].expanded_channels,
+ num_classes=num_classes,
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ x = self.encoder(x)
+ x = self.decoder(x)
+ return x
diff --git a/text_recognizer/networks/stn.py b/text_recognizer/networks/stn.py
new file mode 100644
index 0000000..e9d216f
--- /dev/null
+++ b/text_recognizer/networks/stn.py
@@ -0,0 +1,44 @@
+"""Spatial Transformer Network."""
+
+from einops.layers.torch import Rearrange
+import torch
+from torch import nn
+from torch import Tensor
+import torch.nn.functional as F
+
+
+class SpatialTransformerNetwork(nn.Module):
+ """A network with differentiable attention.
+
+ Network that learns how to perform spatial transformations on the input image in order to enhance the
+ geometric invariance of the model.
+
+ # TODO: add arguments to make it more general.
+
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+ # Initialize the identity transformation and its weights and biases.
+ linear = nn.Linear(32, 3 * 2)
+ linear.weight.data.zero_()
+ linear.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
+
+ self.theta = nn.Sequential(
+ nn.Conv2d(in_channels=1, out_channels=8, kernel_size=7),
+ nn.MaxPool2d(kernel_size=2, stride=2),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(in_channels=8, out_channels=10, kernel_size=5),
+ nn.MaxPool2d(kernel_size=2, stride=2),
+ nn.ReLU(inplace=True),
+ Rearrange("b c h w -> b (c h w)", h=3, w=3),
+ nn.Linear(in_features=10 * 3 * 3, out_features=32),
+ nn.ReLU(inplace=True),
+ linear,
+ Rearrange("b (row col) -> b row col", row=2, col=3),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """The spatial transformation."""
+ grid = F.affine_grid(self.theta(x), x.shape)
+ return F.grid_sample(x, grid, align_corners=False)
diff --git a/text_recognizer/networks/transducer/__init__.py b/text_recognizer/networks/transducer/__init__.py
new file mode 100644
index 0000000..8c19a01
--- /dev/null
+++ b/text_recognizer/networks/transducer/__init__.py
@@ -0,0 +1,3 @@
+"""Transducer modules."""
+from .tds_conv import TDS2d
+from .transducer import load_transducer_loss, Transducer
diff --git a/text_recognizer/networks/transducer/tds_conv.py b/text_recognizer/networks/transducer/tds_conv.py
new file mode 100644
index 0000000..5fb8ba9
--- /dev/null
+++ b/text_recognizer/networks/transducer/tds_conv.py
@@ -0,0 +1,208 @@
+"""Time-Depth Separable Convolutions.
+
+References:
+ https://arxiv.org/abs/1904.02619
+ https://arxiv.org/pdf/2010.01003.pdf
+
+Code stolen from:
+ https://github.com/facebookresearch/gtn_applications
+
+
+"""
+from typing import List, Tuple
+
+from einops import rearrange
+import gtn
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class TDSBlock2d(nn.Module):
+ """Internal block of a 2D TDSC network."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ img_depth: int,
+ kernel_size: Tuple[int],
+ dropout_rate: float,
+ ) -> None:
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.img_depth = img_depth
+ self.kernel_size = kernel_size
+ self.dropout_rate = dropout_rate
+ self.fc_dim = in_channels * img_depth
+
+ # Network placeholders.
+ self.conv = None
+ self.mlp = None
+ self.instance_norm = None
+
+ self._build_block()
+
+ def _build_block(self) -> None:
+ # Convolutional block.
+ self.conv = nn.Sequential(
+ nn.Conv3d(
+ in_channels=self.in_channels,
+ out_channels=self.in_channels,
+ kernel_size=(1, self.kernel_size[0], self.kernel_size[1]),
+ padding=(0, self.kernel_size[0] // 2, self.kernel_size[1] // 2),
+ ),
+ nn.ReLU(inplace=True),
+ nn.Dropout(self.dropout_rate),
+ )
+
+ # MLP block.
+ self.mlp = nn.Sequential(
+ nn.Linear(self.fc_dim, self.fc_dim),
+ nn.ReLU(inplace=True),
+ nn.Dropout(self.dropout_rate),
+ nn.Linear(self.fc_dim, self.fc_dim),
+ nn.Dropout(self.dropout_rate),
+ )
+
+ # Instance norm.
+ self.instance_norm = nn.ModuleList(
+ [
+ nn.InstanceNorm2d(self.fc_dim, affine=True),
+ nn.InstanceNorm2d(self.fc_dim, affine=True),
+ ]
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass.
+
+ Args:
+ x (Tensor): Input tensor.
+
+ Shape:
+ - x: :math: `(B, CD, H, W)`
+
+ Returns:
+ Tensor: Output tensor.
+
+ """
+ B, CD, H, W = x.shape
+ C, D = self.in_channels, self.img_depth
+ residual = x
+ x = rearrange(x, "b (c d) h w -> b c d h w", c=C, d=D)
+ x = self.conv(x)
+ x = rearrange(x, "b c d h w -> b (c d) h w")
+ x += residual
+
+ x = self.instance_norm[0](x)
+
+ x = self.mlp(x.transpose(1, 3)).transpose(1, 3) + x
+ x + self.instance_norm[1](x)
+
+ # Output shape: [B, CD, H, W]
+ return x
+
+
+class TDS2d(nn.Module):
+ """TDS Netowrk.
+
+ Structure is the following:
+ Downsample layer -> TDS2d group -> ... -> Linear output layer
+
+
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ output_dim: int,
+ depth: int,
+ tds_groups: Tuple[int],
+ kernel_size: Tuple[int],
+ dropout_rate: float,
+ in_channels: int = 1,
+ ) -> None:
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.depth = depth
+ self.tds_groups = tds_groups
+ self.kernel_size = kernel_size
+ self.dropout_rate = dropout_rate
+
+ self.tds = None
+ self.fc = None
+
+ self._build_network()
+
+ def _build_network(self) -> None:
+ in_channels = self.in_channels
+ modules = []
+ stride_h = np.prod([grp["stride"][0] for grp in self.tds_groups])
+ if self.input_dim % stride_h:
+ raise RuntimeError(
+ f"Image height not divisible by total stride {stride_h}."
+ )
+
+ for tds_group in self.tds_groups:
+ # Add downsample layer.
+ out_channels = self.depth * tds_group["channels"]
+ modules.extend(
+ [
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=self.kernel_size,
+ padding=(self.kernel_size[0] // 2, self.kernel_size[1] // 2),
+ stride=tds_group["stride"],
+ ),
+ nn.ReLU(inplace=True),
+ nn.Dropout(self.dropout_rate),
+ nn.InstanceNorm2d(out_channels, affine=True),
+ ]
+ )
+
+ for _ in range(tds_group["num_blocks"]):
+ modules.append(
+ TDSBlock2d(
+ tds_group["channels"],
+ self.depth,
+ self.kernel_size,
+ self.dropout_rate,
+ )
+ )
+
+ in_channels = out_channels
+
+ self.tds = nn.Sequential(*modules)
+ self.fc = nn.Linear(in_channels * self.input_dim // stride_h, self.output_dim)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass.
+
+ Args:
+ x (Tensor): Input tensor.
+
+ Shape:
+ - x: :math: `(B, H, W)`
+
+ Returns:
+ Tensor: Output tensor.
+
+ """
+ if len(x.shape) == 4:
+ x = x.squeeze(1) # Squeeze the channel dim away.
+
+ B, H, W = x.shape
+ x = rearrange(
+ x, "b (h1 h2) w -> b h1 h2 w", h1=self.in_channels, h2=H // self.in_channels
+ )
+ x = self.tds(x)
+
+ # x shape: [B, C, H, W]
+ x = rearrange(x, "b c h w -> b w (c h)")
+
+ return self.fc(x)
diff --git a/text_recognizer/networks/transducer/test.py b/text_recognizer/networks/transducer/test.py
new file mode 100644
index 0000000..cadcecc
--- /dev/null
+++ b/text_recognizer/networks/transducer/test.py
@@ -0,0 +1,60 @@
+import torch
+from torch import nn
+
+from text_recognizer.networks.transducer import load_transducer_loss, Transducer
+import unittest
+
+
+class TestTransducer(unittest.TestCase):
+ def test_viterbi(self):
+ T = 5
+ N = 4
+ B = 2
+
+ # fmt: off
+ emissions1 = torch.tensor((
+ 0, 4, 0, 1,
+ 0, 2, 1, 1,
+ 0, 0, 0, 2,
+ 0, 0, 0, 2,
+ 8, 0, 0, 2,
+ ),
+ dtype=torch.float,
+ ).view(T, N)
+ emissions2 = torch.tensor((
+ 0, 2, 1, 7,
+ 0, 2, 9, 1,
+ 0, 0, 0, 2,
+ 0, 0, 5, 2,
+ 1, 0, 0, 2,
+ ),
+ dtype=torch.float,
+ ).view(T, N)
+ # fmt: on
+
+ # Test without blank:
+ labels = [[1, 3, 0], [3, 2, 3, 2, 3]]
+ transducer = Transducer(
+ tokens=["a", "b", "c", "d"],
+ graphemes_to_idx={"a": 0, "b": 1, "c": 2, "d": 3},
+ blank="none",
+ )
+ emissions = torch.stack([emissions1, emissions2], dim=0)
+ predictions = transducer.viterbi(emissions)
+ self.assertEqual([p.tolist() for p in predictions], labels)
+
+ # Test with blank without repeats:
+ labels = [[1, 0], [2, 2]]
+ transducer = Transducer(
+ tokens=["a", "b", "c"],
+ graphemes_to_idx={"a": 0, "b": 1, "c": 2},
+ blank="optional",
+ allow_repeats=False,
+ )
+ emissions = torch.stack([emissions1, emissions2], dim=0)
+ predictions = transducer.viterbi(emissions)
+ self.assertEqual([p.tolist() for p in predictions], labels)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/text_recognizer/networks/transducer/transducer.py b/text_recognizer/networks/transducer/transducer.py
new file mode 100644
index 0000000..d7e3d08
--- /dev/null
+++ b/text_recognizer/networks/transducer/transducer.py
@@ -0,0 +1,410 @@
+"""Transducer and the transducer loss function.py
+
+Stolen from:
+ https://github.com/facebookresearch/gtn_applications/blob/master/transducer.py
+
+"""
+from pathlib import Path
+import itertools
+from typing import Dict, List, Optional, Union, Tuple
+
+from loguru import logger
+import gtn
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.datasets.iam_preprocessor import Preprocessor
+
+
+def make_scalar_graph(weight) -> gtn.Graph:
+ scalar = gtn.Graph()
+ scalar.add_node(True)
+ scalar.add_node(False, True)
+ scalar.add_arc(0, 1, 0, 0, weight)
+ return scalar
+
+
+def make_chain_graph(sequence) -> gtn.Graph:
+ graph = gtn.Graph(False)
+ graph.add_node(True)
+ for i, s in enumerate(sequence):
+ graph.add_node(False, i == (len(sequence) - 1))
+ graph.add_arc(i, i + 1, s)
+ return graph
+
+
+def make_transitions_graph(
+ ngram: int, num_tokens: int, calc_grad: bool = False
+) -> gtn.Graph:
+ transitions = gtn.Graph(calc_grad)
+ transitions.add_node(True, ngram == 1)
+
+ state_map = {(): 0}
+
+ # First build transitions which include <s>:
+ for n in range(1, ngram):
+ for state in itertools.product(range(num_tokens), repeat=n):
+ in_idx = state_map[state[:-1]]
+ out_idx = transitions.add_node(False, ngram == 1)
+ state_map[state] = out_idx
+ transitions.add_arc(in_idx, out_idx, state[-1])
+
+ for state in itertools.product(range(num_tokens), repeat=ngram):
+ state_idx = state_map[state[:-1]]
+ new_state_idx = state_map[state[1:]]
+ # p(state[-1] | state[:-1])
+ transitions.add_arc(state_idx, new_state_idx, state[-1])
+
+ if ngram > 1:
+ # Build transitions which include </s>:
+ end_idx = transitions.add_node(False, True)
+ for in_idx in range(end_idx):
+ transitions.add_arc(in_idx, end_idx, gtn.epsilon)
+
+ return transitions
+
+
+def make_lexicon_graph(word_pieces: List, graphemes_to_idx: Dict) -> gtn.Graph:
+ """Constructs a graph which transduces letters to word pieces."""
+ graph = gtn.Graph(False)
+ graph.add_node(True, True)
+ for i, wp in enumerate(word_pieces):
+ prev = 0
+ for l in wp[:-1]:
+ n = graph.add_node()
+ graph.add_arc(prev, n, graphemes_to_idx[l], gtn.epsilon)
+ prev = n
+ graph.add_arc(prev, 0, graphemes_to_idx[wp[-1]], i)
+ graph.arc_sort()
+ return graph
+
+
+def make_token_graph(
+ token_list: List, blank: str = "none", allow_repeats: bool = True
+) -> gtn.Graph:
+ """Constructs a graph with all the individual token transition models."""
+ if not allow_repeats and blank != "optional":
+ raise ValueError("Must use blank='optional' if disallowing repeats.")
+
+ ntoks = len(token_list)
+ graph = gtn.Graph(False)
+
+ # Creating nodes
+ graph.add_node(True, True)
+ for i in range(ntoks):
+ # We can consume one or more consecutive word
+ # pieces for each emission:
+ # E.g. [ab, ab, ab] transduces to [ab]
+ graph.add_node(False, blank != "forced")
+
+ if blank != "none":
+ graph.add_node()
+
+ # Creating arcs
+ if blank != "none":
+ # Blank index is assumed to be last (ntoks)
+ graph.add_arc(0, ntoks + 1, ntoks, gtn.epsilon)
+ graph.add_arc(ntoks + 1, 0, gtn.epsilon)
+
+ for i in range(ntoks):
+ graph.add_arc((ntoks + 1) if blank == "forced" else 0, i + 1, i)
+ graph.add_arc(i + 1, i + 1, i, gtn.epsilon)
+
+ if allow_repeats:
+ if blank == "forced":
+ # Allow transitions from token to blank only
+ graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon)
+ else:
+ # Allow transition from token to blank and all other tokens
+ graph.add_arc(i + 1, 0, gtn.epsilon)
+
+ else:
+ # allow transitions to blank and all other tokens except the same token
+ graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon)
+ for j in range(ntoks):
+ if i != j:
+ graph.add_arc(i + 1, j + 1, j, j)
+
+ return graph
+
+
+class TransducerLossFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ inputs,
+ targets,
+ tokens,
+ lexicon,
+ transition_params=None,
+ transitions=None,
+ reduction="none",
+ ) -> Tensor:
+ B, T, C = inputs.shape
+
+ losses = [None] * B
+ emissions_graphs = [None] * B
+
+ if transitions is not None:
+ if transition_params is None:
+ raise ValueError("Specified transitions, but not transition params.")
+
+ cpu_data = transition_params.cpu().contiguous()
+ transitions.set_weights(cpu_data.data_ptr())
+ transitions.calc_grad = transition_params.requires_grad
+ transitions.zero_grad()
+
+ def process(b: int) -> None:
+ # Create emission graph:
+ emissions = gtn.linear_graph(T, C, inputs.requires_grad)
+ cpu_data = inputs[b].cpu().contiguous()
+ emissions.set_weights(cpu_data.data_ptr())
+ target = make_chain_graph(targets[b])
+ target.arc_sort(True)
+
+ # Create token tot grapheme decomposition graph
+ tokens_target = gtn.remove(gtn.project_output(gtn.compose(target, lexicon)))
+ tokens_target.arc_sort()
+
+ # Create alignment graph:
+ aligments = gtn.project_input(
+ gtn.remove(gtn.compose(tokens, tokens_target))
+ )
+ aligments.arc_sort()
+
+ # Add transitions scores:
+ if transitions is not None:
+ aligments = gtn.intersect(transitions, aligments)
+ aligments.arc_sort()
+
+ loss = gtn.forward_score(gtn.intersect(emissions, aligments))
+
+ # Normalize if needed:
+ if transitions is not None:
+ norm = gtn.forward_score(gtn.intersect(emissions, transitions))
+ loss = gtn.subtract(loss, norm)
+
+ losses[b] = gtn.negate(loss)
+
+ # Save for backward:
+ if emissions.calc_grad:
+ emissions_graphs[b] = emissions
+
+ gtn.parallel_for(process, range(B))
+
+ ctx.graphs = (losses, emissions_graphs, transitions)
+ ctx.input_shape = inputs.shape
+
+ # Optionally reduce by target length
+ if reduction == "mean":
+ scales = [(1 / len(t) if len(t) > 0 else 1.0) for t in targets]
+ else:
+ scales = [1.0] * B
+
+ ctx.scales = scales
+
+ loss = torch.tensor([l.item() * s for l, s in zip(losses, scales)])
+ return torch.mean(loss.to(inputs.device))
+
+ @staticmethod
+ def backward(ctx, grad_output) -> Tuple:
+ losses, emissions_graphs, transitions = ctx.graphs
+ scales = ctx.scales
+
+ B, T, C = ctx.input_shape
+ calc_emissions = ctx.needs_input_grad[0]
+ input_grad = torch.empty((B, T, C)) if calc_emissions else None
+
+ def process(b: int) -> None:
+ scale = make_scalar_graph(scales[b])
+ gtn.backward(losses[b], scale)
+ emissions = emissions_graphs[b]
+ if calc_emissions:
+ grad = emissions.grad().weights_to_numpy()
+ input_grad[b] = torch.tensor(grad).view(1, T, C)
+
+ gtn.parallel_for(process, range(B))
+
+ if calc_emissions:
+ input_grad = input_grad.to(grad_output.device)
+ input_grad *= grad_output / B
+
+ if ctx.needs_input_grad[4]:
+ grad = transitions.grad().weights_to_numpy()
+ transition_grad = torch.tensor(grad).to(grad_output.device)
+ transition_grad *= grad_output / B
+ else:
+ transition_grad = None
+
+ return (
+ input_grad,
+ None, # target
+ None, # tokens
+ None, # lexicon
+ transition_grad, # transition params
+ None, # transitions graph
+ None,
+ )
+
+
+TransducerLoss = TransducerLossFunction.apply
+
+
+class Transducer(nn.Module):
+ def __init__(
+ self,
+ tokens: List,
+ graphemes_to_idx: Dict,
+ ngram: int = 0,
+ transitions: str = None,
+ blank: str = "none",
+ allow_repeats: bool = True,
+ reduction: str = "none",
+ ) -> None:
+ """A generic transducer loss function.
+
+ Args:
+ tokens (List) : A list of iterable objects (e.g. strings, tuples, etc)
+ representing the output tokens of the model (e.g. letters,
+ word-pieces, words). For example ["a", "b", "ab", "ba", "aba"]
+ could be a list of sub-word tokens.
+ graphemes_to_idx (dict) : A dictionary mapping grapheme units (e.g.
+ "a", "b", ..) to their corresponding integer index.
+ ngram (int) : Order of the token-level transition model. If `ngram=0`
+ then no transition model is used.
+ blank (string) : Specifies the usage of blank token
+ 'none' - do not use blank token
+ 'optional' - allow an optional blank inbetween tokens
+ 'forced' - force a blank inbetween tokens (also referred to as garbage token)
+ allow_repeats (boolean) : If false, then we don't allow paths with
+ consecutive tokens in the alignment graph. This keeps the graph
+ unambiguous in the sense that the same input cannot transduce to
+ different outputs.
+ """
+ super().__init__()
+ if blank not in ["optional", "forced", "none"]:
+ raise ValueError(
+ "Invalid value specified for blank. Must be in ['optional', 'forced', 'none']"
+ )
+ self.tokens = make_token_graph(tokens, blank=blank, allow_repeats=allow_repeats)
+ self.lexicon = make_lexicon_graph(tokens, graphemes_to_idx)
+ self.ngram = ngram
+ if ngram > 0 and transitions is not None:
+ raise ValueError("Only one of ngram and transitions may be specified")
+
+ if ngram > 0:
+ transitions = make_transitions_graph(
+ ngram, len(tokens) + int(blank != "none"), True
+ )
+
+ if transitions is not None:
+ self.transitions = transitions
+ self.transitions.arc_sort()
+ self.transitions_params = nn.Parameter(
+ torch.zeros(self.transitions.num_arcs())
+ )
+ else:
+ self.transitions = None
+ self.transitions_params = None
+ self.reduction = reduction
+
+ def forward(self, inputs: Tensor, targets: Tensor) -> TransducerLoss:
+ TransducerLoss(
+ inputs,
+ targets,
+ self.tokens,
+ self.lexicon,
+ self.transitions_params,
+ self.transitions,
+ self.reduction,
+ )
+
+ def viterbi(self, outputs: Tensor) -> List[Tensor]:
+ B, T, C = outputs.shape
+
+ if self.transitions is not None:
+ cpu_data = self.transition_params.cpu().contiguous()
+ self.transitions.set_weights(cpu_data.data_ptr())
+ self.transitions.calc_grad = False
+
+ self.tokens.arc_sort()
+
+ paths = [None] * B
+
+ def process(b: int) -> None:
+ emissions = gtn.linear_graph(T, C, False)
+ cpu_data = outputs[b].cpu().contiguous()
+ emissions.set_weights(cpu_data.data_ptr())
+
+ if self.transitions is not None:
+ full_graph = gtn.intersect(emissions, self.transitions)
+ else:
+ full_graph = emissions
+
+ # Find the best path and remove back-off arcs:
+ path = gtn.remove(gtn.viterbi_path(full_graph))
+
+ # Left compose the viterbi path with the "aligment to token"
+ # transducer to get the outputs:
+ path = gtn.compose(path, self.tokens)
+
+ # When there are ambiguous paths (allow_repeats is true), we take
+ # the shortest:
+ path = gtn.viterbi_path(path)
+ path = gtn.remove(gtn.project_output(path))
+ paths[b] = path.labels_to_list()
+
+ gtn.parallel_for(process, range(B))
+ predictions = [torch.IntTensor(path) for path in paths]
+ return predictions
+
+
+def load_transducer_loss(
+ num_features: int,
+ ngram: int,
+ tokens: str,
+ lexicon: str,
+ transitions: str,
+ blank: str,
+ allow_repeats: bool,
+ prepend_wordsep: bool = False,
+ use_words: bool = False,
+ data_dir: Optional[Union[str, Path]] = None,
+ reduction: str = "mean",
+) -> Tuple[Transducer, int]:
+ if data_dir is None:
+ data_dir = (
+ Path(__file__).resolve().parents[4] / "data" / "raw" / "iam" / "iamdb"
+ )
+ logger.debug(f"Using data dir: {data_dir}")
+ if not data_dir.exists():
+ raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+ else:
+ data_dir = Path(data_dir)
+ processed_path = (
+ Path(__file__).resolve().parents[4] / "data" / "processed" / "iam_lines"
+ )
+ tokens_path = processed_path / tokens
+ lexicon_path = processed_path / lexicon
+
+ if transitions is not None:
+ transitions = gtn.load(str(processed_path / transitions))
+
+ preprocessor = Preprocessor(
+ data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep,
+ )
+
+ num_tokens = preprocessor.num_tokens
+
+ criterion = Transducer(
+ preprocessor.tokens,
+ preprocessor.graphemes_to_index,
+ ngram=ngram,
+ transitions=transitions,
+ blank=blank,
+ allow_repeats=allow_repeats,
+ reduction=reduction,
+ )
+
+ return criterion, num_tokens + int(blank != "none")
diff --git a/text_recognizer/networks/transformer/__init__.py b/text_recognizer/networks/transformer/__init__.py
new file mode 100644
index 0000000..9febc88
--- /dev/null
+++ b/text_recognizer/networks/transformer/__init__.py
@@ -0,0 +1,3 @@
+"""Transformer modules."""
+from .positional_encoding import PositionalEncoding
+from .transformer import Decoder, Encoder, EncoderLayer, Transformer
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
new file mode 100644
index 0000000..cce1ecc
--- /dev/null
+++ b/text_recognizer/networks/transformer/attention.py
@@ -0,0 +1,93 @@
+"""Implementes the attention module for the transformer."""
+from typing import Optional, Tuple
+
+from einops import rearrange
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class MultiHeadAttention(nn.Module):
+ """Implementation of multihead attention."""
+
+ def __init__(
+ self, hidden_dim: int, num_heads: int = 8, dropout_rate: float = 0.0
+ ) -> None:
+ super().__init__()
+ self.hidden_dim = hidden_dim
+ self.num_heads = num_heads
+ self.fc_q = nn.Linear(
+ in_features=hidden_dim, out_features=hidden_dim, bias=False
+ )
+ self.fc_k = nn.Linear(
+ in_features=hidden_dim, out_features=hidden_dim, bias=False
+ )
+ self.fc_v = nn.Linear(
+ in_features=hidden_dim, out_features=hidden_dim, bias=False
+ )
+ self.fc_out = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
+
+ self._init_weights()
+
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ def _init_weights(self) -> None:
+ nn.init.normal_(
+ self.fc_q.weight,
+ mean=0,
+ std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
+ )
+ nn.init.normal_(
+ self.fc_k.weight,
+ mean=0,
+ std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
+ )
+ nn.init.normal_(
+ self.fc_v.weight,
+ mean=0,
+ std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
+ )
+ nn.init.xavier_normal_(self.fc_out.weight)
+
+ def scaled_dot_product_attention(
+ self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
+ ) -> Tensor:
+ """Calculates the scaled dot product attention."""
+
+ # Compute the energy.
+ energy = torch.einsum("bhlk,bhtk->bhlt", [query, key]) / np.sqrt(
+ query.shape[-1]
+ )
+
+ # If we have a mask for padding some inputs.
+ if mask is not None:
+ energy = energy.masked_fill(mask == 0, -np.inf)
+
+ # Compute the attention from the energy.
+ attention = torch.softmax(energy, dim=3)
+
+ out = torch.einsum("bhlt,bhtv->bhlv", [attention, value])
+ out = rearrange(out, "b head l v -> b l (head v)")
+ return out, attention
+
+ def forward(
+ self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
+ ) -> Tuple[Tensor, Tensor]:
+ """Forward pass for computing the multihead attention."""
+ # Get the query, key, and value tensor.
+ query = rearrange(
+ self.fc_q(query), "b l (head k) -> b head l k", head=self.num_heads
+ )
+ key = rearrange(
+ self.fc_k(key), "b t (head k) -> b head t k", head=self.num_heads
+ )
+ value = rearrange(
+ self.fc_v(value), "b t (head v) -> b head t v", head=self.num_heads
+ )
+
+ out, attention = self.scaled_dot_product_attention(query, key, value, mask)
+
+ out = self.fc_out(out)
+ out = self.dropout(out)
+ return out, attention
diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py
new file mode 100644
index 0000000..1ba5537
--- /dev/null
+++ b/text_recognizer/networks/transformer/positional_encoding.py
@@ -0,0 +1,32 @@
+"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence."""
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class PositionalEncoding(nn.Module):
+ """Encodes a sense of distance or time for transformer networks."""
+
+ def __init__(
+ self, hidden_dim: int, dropout_rate: float, max_len: int = 1000
+ ) -> None:
+ super().__init__()
+ self.dropout = nn.Dropout(p=dropout_rate)
+ self.max_len = max_len
+
+ pe = torch.zeros(max_len, hidden_dim)
+ position = torch.arange(0, max_len).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim)
+ )
+
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0)
+ self.register_buffer("pe", pe)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Encodes the tensor with a postional embedding."""
+ x = x + self.pe[:, : x.shape[1]]
+ return self.dropout(x)
diff --git a/text_recognizer/networks/transformer/transformer.py b/text_recognizer/networks/transformer/transformer.py
new file mode 100644
index 0000000..dd180c4
--- /dev/null
+++ b/text_recognizer/networks/transformer/transformer.py
@@ -0,0 +1,264 @@
+"""Transfomer module."""
+import copy
+from typing import Dict, Optional, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+import torch.nn.functional as F
+
+from text_recognizer.networks.transformer.attention import MultiHeadAttention
+from text_recognizer.networks.util import activation_function
+
+
+class GEGLU(nn.Module):
+ """GLU activation for improving feedforward activations."""
+
+ def __init__(self, dim_in: int, dim_out: int) -> None:
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward propagation."""
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList:
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)])
+
+
+class _IntraLayerConnection(nn.Module):
+ """Preforms the residual connection inside the transfomer blocks and applies layernorm."""
+
+ def __init__(self, dropout_rate: float, hidden_dim: int) -> None:
+ super().__init__()
+ self.norm = nn.LayerNorm(normalized_shape=hidden_dim)
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ def forward(self, src: Tensor, residual: Tensor) -> Tensor:
+ return self.norm(self.dropout(src) + residual)
+
+
+class _ConvolutionalLayer(nn.Module):
+ def __init__(
+ self,
+ hidden_dim: int,
+ expansion_dim: int,
+ dropout_rate: float,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+
+ in_projection = (
+ nn.Sequential(
+ nn.Linear(hidden_dim, expansion_dim), activation_function(activation)
+ )
+ if activation != "glu"
+ else GEGLU(hidden_dim, expansion_dim)
+ )
+
+ self.layer = nn.Sequential(
+ in_projection,
+ nn.Dropout(p=dropout_rate),
+ nn.Linear(in_features=expansion_dim, out_features=hidden_dim),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.layer(x)
+
+
+class EncoderLayer(nn.Module):
+ """Transfomer encoding layer."""
+
+ def __init__(
+ self,
+ hidden_dim: int,
+ num_heads: int,
+ expansion_dim: int,
+ dropout_rate: float,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
+ self.cnn = _ConvolutionalLayer(
+ hidden_dim, expansion_dim, dropout_rate, activation
+ )
+ self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
+ self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
+
+ def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor:
+ """Forward pass through the encoder."""
+ # First block.
+ # Multi head attention.
+ out, _ = self.self_attention(src, src, src, mask)
+
+ # Add & norm.
+ out = self.block1(out, src)
+
+ # Second block.
+ # Apply 1D-convolution.
+ cnn_out = self.cnn(out)
+
+ # Add & norm.
+ out = self.block2(cnn_out, out)
+
+ return out
+
+
+class Encoder(nn.Module):
+ """Transfomer encoder module."""
+
+ def __init__(
+ self,
+ num_layers: int,
+ encoder_layer: Type[nn.Module],
+ norm: Optional[Type[nn.Module]] = None,
+ ) -> None:
+ super().__init__()
+ self.layers = _get_clones(encoder_layer, num_layers)
+ self.norm = norm
+
+ def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
+ """Forward pass through all encoder layers."""
+ for layer in self.layers:
+ src = layer(src, src_mask)
+
+ if self.norm is not None:
+ src = self.norm(src)
+
+ return src
+
+
+class DecoderLayer(nn.Module):
+ """Transfomer decoder layer."""
+
+ def __init__(
+ self,
+ hidden_dim: int,
+ num_heads: int,
+ expansion_dim: int,
+ dropout_rate: float = 0.0,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ self.hidden_dim = hidden_dim
+ self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
+ self.multihead_attention = MultiHeadAttention(
+ hidden_dim, num_heads, dropout_rate
+ )
+ self.cnn = _ConvolutionalLayer(
+ hidden_dim, expansion_dim, dropout_rate, activation
+ )
+ self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
+ self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
+ self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim)
+
+ def forward(
+ self,
+ trg: Tensor,
+ memory: Tensor,
+ trg_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """Forward pass of the layer."""
+ out, _ = self.self_attention(trg, trg, trg, trg_mask)
+ trg = self.block1(out, trg)
+
+ out, _ = self.multihead_attention(trg, memory, memory, memory_mask)
+ trg = self.block2(out, trg)
+
+ out = self.cnn(trg)
+ out = self.block3(out, trg)
+
+ return out
+
+
+class Decoder(nn.Module):
+ """Transfomer decoder module."""
+
+ def __init__(
+ self,
+ decoder_layer: Type[nn.Module],
+ num_layers: int,
+ norm: Optional[Type[nn.Module]] = None,
+ ) -> None:
+ super().__init__()
+ self.layers = _get_clones(decoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+
+ def forward(
+ self,
+ trg: Tensor,
+ memory: Tensor,
+ trg_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """Forward pass through the decoder."""
+ for layer in self.layers:
+ trg = layer(trg, memory, trg_mask, memory_mask)
+
+ if self.norm is not None:
+ trg = self.norm(trg)
+
+ return trg
+
+
+class Transformer(nn.Module):
+ """Transformer network."""
+
+ def __init__(
+ self,
+ num_encoder_layers: int,
+ num_decoder_layers: int,
+ hidden_dim: int,
+ num_heads: int,
+ expansion_dim: int,
+ dropout_rate: float,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+
+ # Configure encoder.
+ encoder_norm = nn.LayerNorm(hidden_dim)
+ encoder_layer = EncoderLayer(
+ hidden_dim, num_heads, expansion_dim, dropout_rate, activation
+ )
+ self.encoder = Encoder(num_encoder_layers, encoder_layer, encoder_norm)
+
+ # Configure decoder.
+ decoder_norm = nn.LayerNorm(hidden_dim)
+ decoder_layer = DecoderLayer(
+ hidden_dim, num_heads, expansion_dim, dropout_rate, activation
+ )
+ self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm)
+
+ self._reset_parameters()
+
+ def _reset_parameters(self) -> None:
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward(
+ self,
+ src: Tensor,
+ trg: Tensor,
+ src_mask: Optional[Tensor] = None,
+ trg_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """Forward pass through the transformer."""
+ if src.shape[0] != trg.shape[0]:
+ print(trg.shape)
+ raise RuntimeError("The batch size of the src and trg must be the same.")
+ if src.shape[2] != trg.shape[2]:
+ raise RuntimeError(
+ "The number of features for the src and trg must be the same."
+ )
+
+ memory = self.encoder(src, src_mask)
+ output = self.decoder(trg, memory, trg_mask, memory_mask)
+ return output
diff --git a/text_recognizer/networks/unet.py b/text_recognizer/networks/unet.py
new file mode 100644
index 0000000..510910f
--- /dev/null
+++ b/text_recognizer/networks/unet.py
@@ -0,0 +1,255 @@
+"""UNet for segmentation."""
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import activation_function
+
+
+class _ConvBlock(nn.Module):
+ """Modified UNet convolutional block with dilation."""
+
+ def __init__(
+ self,
+ channels: List[int],
+ activation: str,
+ num_groups: int,
+ dropout_rate: float = 0.1,
+ kernel_size: int = 3,
+ dilation: int = 1,
+ padding: int = 0,
+ ) -> None:
+ super().__init__()
+ self.channels = channels
+ self.dropout_rate = dropout_rate
+ self.kernel_size = kernel_size
+ self.dilation = dilation
+ self.padding = padding
+ self.num_groups = num_groups
+ self.activation = activation_function(activation)
+ self.block = self._configure_block()
+ self.residual_conv = nn.Sequential(
+ nn.Conv2d(
+ self.channels[0], self.channels[-1], kernel_size=3, stride=1, padding=1
+ ),
+ self.activation,
+ )
+
+ def _configure_block(self) -> nn.Sequential:
+ block = []
+ for i in range(len(self.channels) - 1):
+ block += [
+ nn.Dropout(p=self.dropout_rate),
+ nn.GroupNorm(self.num_groups, self.channels[i]),
+ self.activation,
+ nn.Conv2d(
+ self.channels[i],
+ self.channels[i + 1],
+ kernel_size=self.kernel_size,
+ padding=self.padding,
+ stride=1,
+ dilation=self.dilation,
+ ),
+ ]
+
+ return nn.Sequential(*block)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Apply the convolutional block."""
+ residual = self.residual_conv(x)
+ return self.block(x) + residual
+
+
+class _DownSamplingBlock(nn.Module):
+ """Basic down sampling block."""
+
+ def __init__(
+ self,
+ channels: List[int],
+ activation: str,
+ num_groups: int,
+ pooling_kernel: Union[int, bool] = 2,
+ dropout_rate: float = 0.1,
+ kernel_size: int = 3,
+ dilation: int = 1,
+ padding: int = 0,
+ ) -> None:
+ super().__init__()
+ self.conv_block = _ConvBlock(
+ channels,
+ activation,
+ num_groups,
+ dropout_rate,
+ kernel_size,
+ dilation,
+ padding,
+ )
+ self.down_sampling = nn.MaxPool2d(pooling_kernel) if pooling_kernel else None
+
+ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ """Return the convolutional block output and a down sampled tensor."""
+ x = self.conv_block(x)
+ x_down = self.down_sampling(x) if self.down_sampling is not None else x
+
+ return x_down, x
+
+
+class _UpSamplingBlock(nn.Module):
+ """The upsampling block of the UNet."""
+
+ def __init__(
+ self,
+ channels: List[int],
+ activation: str,
+ num_groups: int,
+ scale_factor: int = 2,
+ dropout_rate: float = 0.1,
+ kernel_size: int = 3,
+ dilation: int = 1,
+ padding: int = 0,
+ ) -> None:
+ super().__init__()
+ self.conv_block = _ConvBlock(
+ channels,
+ activation,
+ num_groups,
+ dropout_rate,
+ kernel_size,
+ dilation,
+ padding,
+ )
+ self.up_sampling = nn.Upsample(
+ scale_factor=scale_factor, mode="bilinear", align_corners=True
+ )
+
+ def forward(self, x: Tensor, x_skip: Optional[Tensor] = None) -> Tensor:
+ """Apply the up sampling and convolutional block."""
+ x = self.up_sampling(x)
+ if x_skip is not None:
+ x = torch.cat((x, x_skip), dim=1)
+ return self.conv_block(x)
+
+
+class UNet(nn.Module):
+ """UNet architecture."""
+
+ def __init__(
+ self,
+ in_channels: int = 1,
+ base_channels: int = 64,
+ num_classes: int = 3,
+ depth: int = 4,
+ activation: str = "relu",
+ num_groups: int = 8,
+ dropout_rate: float = 0.1,
+ pooling_kernel: int = 2,
+ scale_factor: int = 2,
+ kernel_size: Optional[List[int]] = None,
+ dilation: Optional[List[int]] = None,
+ padding: Optional[List[int]] = None,
+ ) -> None:
+ super().__init__()
+ self.depth = depth
+ self.num_groups = num_groups
+
+ if kernel_size is not None and dilation is not None and padding is not None:
+ if (
+ len(kernel_size) != depth
+ and len(dilation) != depth
+ and len(padding) != depth
+ ):
+ raise RuntimeError(
+ "Length of convolutional parameters does not match the depth."
+ )
+ self.kernel_size = kernel_size
+ self.padding = padding
+ self.dilation = dilation
+
+ else:
+ self.kernel_size = [3] * depth
+ self.padding = [1] * depth
+ self.dilation = [1] * depth
+
+ self.dropout_rate = dropout_rate
+ self.conv = nn.Conv2d(
+ in_channels, base_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ channels = [base_channels] + [base_channels * 2 ** i for i in range(depth)]
+ self.encoder_blocks = self._configure_down_sampling_blocks(
+ channels, activation, pooling_kernel
+ )
+ self.decoder_blocks = self._configure_up_sampling_blocks(
+ channels, activation, scale_factor
+ )
+
+ self.head = nn.Conv2d(base_channels, num_classes, kernel_size=1)
+
+ def _configure_down_sampling_blocks(
+ self, channels: List[int], activation: str, pooling_kernel: int
+ ) -> nn.ModuleList:
+ blocks = nn.ModuleList([])
+ for i in range(len(channels) - 1):
+ pooling_kernel = pooling_kernel if i < self.depth - 1 else False
+ dropout_rate = self.dropout_rate if i < 0 else 0
+ blocks += [
+ _DownSamplingBlock(
+ [channels[i], channels[i + 1], channels[i + 1]],
+ activation,
+ self.num_groups,
+ pooling_kernel,
+ dropout_rate,
+ self.kernel_size[i],
+ self.dilation[i],
+ self.padding[i],
+ )
+ ]
+
+ return blocks
+
+ def _configure_up_sampling_blocks(
+ self, channels: List[int], activation: str, scale_factor: int,
+ ) -> nn.ModuleList:
+ channels.reverse()
+ self.kernel_size.reverse()
+ self.dilation.reverse()
+ self.padding.reverse()
+ return nn.ModuleList(
+ [
+ _UpSamplingBlock(
+ [channels[i] + channels[i + 1], channels[i + 1], channels[i + 1]],
+ activation,
+ self.num_groups,
+ scale_factor,
+ self.dropout_rate,
+ self.kernel_size[i],
+ self.dilation[i],
+ self.padding[i],
+ )
+ for i in range(len(channels) - 2)
+ ]
+ )
+
+ def _encode(self, x: Tensor) -> List[Tensor]:
+ x_skips = []
+ for block in self.encoder_blocks:
+ x, x_skip = block(x)
+ x_skips.append(x_skip)
+ return x_skips
+
+ def _decode(self, x_skips: List[Tensor]) -> Tensor:
+ x = x_skips[-1]
+ for i, block in enumerate(self.decoder_blocks):
+ x = block(x, x_skips[-(i + 2)])
+ return x
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass with the UNet model."""
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
+ x = self.conv(x)
+ x_skips = self._encode(x)
+ x = self._decode(x_skips)
+ return self.head(x)
diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py
new file mode 100644
index 0000000..131a6b4
--- /dev/null
+++ b/text_recognizer/networks/util.py
@@ -0,0 +1,89 @@
+"""Miscellaneous neural network functionality."""
+import importlib
+from pathlib import Path
+from typing import Dict, Tuple, Type
+
+from einops import rearrange
+from loguru import logger
+import torch
+from torch import nn
+
+
+def sliding_window(
+ images: torch.Tensor, patch_size: Tuple[int, int], stride: Tuple[int, int]
+) -> torch.Tensor:
+ """Creates patches of an image.
+
+ Args:
+ images (torch.Tensor): A Torch tensor of a 4D image(s), i.e. (batch, channel, height, width).
+ patch_size (Tuple[int, int]): The size of the patches to generate, e.g. 28x28 for EMNIST.
+ stride (Tuple[int, int]): The stride of the sliding window.
+
+ Returns:
+ torch.Tensor: A tensor with the shape (batch, patches, height, width).
+
+ """
+ unfold = nn.Unfold(kernel_size=patch_size, stride=stride)
+ # Preform the sliding window, unsqueeze as the channel dimesion is lost.
+ c = images.shape[1]
+ patches = unfold(images)
+ patches = rearrange(
+ patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1],
+ )
+ return patches
+
+
+def activation_function(activation: str) -> Type[nn.Module]:
+ """Returns the callable activation function."""
+ activation_fns = nn.ModuleDict(
+ [
+ ["elu", nn.ELU(inplace=True)],
+ ["gelu", nn.GELU()],
+ ["glu", nn.GLU()],
+ ["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)],
+ ["none", nn.Identity()],
+ ["relu", nn.ReLU(inplace=True)],
+ ["selu", nn.SELU(inplace=True)],
+ ]
+ )
+ return activation_fns[activation.lower()]
+
+
+def configure_backbone(backbone: str, backbone_args: Dict) -> Type[nn.Module]:
+ """Loads a backbone network."""
+ network_module = importlib.import_module("text_recognizer.networks")
+ backbone_ = getattr(network_module, backbone)
+
+ if "pretrained" in backbone_args:
+ logger.info("Loading pretrained backbone.")
+ checkpoint_file = Path(__file__).resolve().parents[2] / backbone_args.pop(
+ "pretrained"
+ )
+
+ # Loading state directory.
+ state_dict = torch.load(checkpoint_file)
+ network_args = state_dict["network_args"]
+ weights = state_dict["model_state"]
+
+ freeze = False
+ if "freeze" in backbone_args and backbone_args["freeze"] is True:
+ backbone_args.pop("freeze")
+ freeze = True
+ network_args = backbone_args
+
+ # Initializes the network with trained weights.
+ backbone = backbone_(**network_args)
+ backbone.load_state_dict(weights)
+ if freeze:
+ for params in backbone.parameters():
+ params.requires_grad = False
+ else:
+ backbone_ = getattr(network_module, backbone)
+ backbone = backbone_(**backbone_args)
+
+ if "remove_layers" in backbone_args and backbone_args["remove_layers"] is not None:
+ backbone = nn.Sequential(
+ *list(backbone.children())[:][: -backbone_args["remove_layers"]]
+ )
+
+ return backbone
diff --git a/text_recognizer/networks/vit.py b/text_recognizer/networks/vit.py
new file mode 100644
index 0000000..efb3701
--- /dev/null
+++ b/text_recognizer/networks/vit.py
@@ -0,0 +1,150 @@
+"""A Vision Transformer.
+
+Inspired by:
+https://openreview.net/pdf?id=YicbFdNTTy
+
+"""
+from typing import Optional, Tuple
+
+from einops import rearrange, repeat
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer import Transformer
+
+
+class ViT(nn.Module):
+ """Transfomer for image to sequence prediction."""
+
+ def __init__(
+ self,
+ num_encoder_layers: int,
+ num_decoder_layers: int,
+ hidden_dim: int,
+ vocab_size: int,
+ num_heads: int,
+ expansion_dim: int,
+ patch_dim: Tuple[int, int],
+ image_size: Tuple[int, int],
+ dropout_rate: float,
+ trg_pad_index: int,
+ max_len: int,
+ activation: str = "gelu",
+ ) -> None:
+ super().__init__()
+
+ self.trg_pad_index = trg_pad_index
+ self.patch_dim = patch_dim
+ self.num_patches = image_size[-1] // self.patch_dim[1]
+
+ # Encoder
+ self.patch_to_embedding = nn.Linear(
+ self.patch_dim[0] * self.patch_dim[1], hidden_dim
+ )
+ self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
+ self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
+ self.pos_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
+ self.dropout = nn.Dropout(dropout_rate)
+ self._init()
+
+ self.transformer = Transformer(
+ num_encoder_layers,
+ num_decoder_layers,
+ hidden_dim,
+ num_heads,
+ expansion_dim,
+ dropout_rate,
+ activation,
+ )
+
+ self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
+
+ def _init(self) -> None:
+ nn.init.normal_(self.character_embedding.weight, std=0.02)
+ # nn.init.normal_(self.pos_embedding.weight, std=0.02)
+
+ def _create_trg_mask(self, trg: Tensor) -> Tensor:
+ # Move this outside the transformer.
+ trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
+ trg_len = trg.shape[1]
+ trg_sub_mask = torch.tril(
+ torch.ones((trg_len, trg_len), device=trg.device)
+ ).bool()
+ trg_mask = trg_pad_mask & trg_sub_mask
+ return trg_mask
+
+ def encoder(self, src: Tensor) -> Tensor:
+ """Forward pass with the encoder of the transformer."""
+ return self.transformer.encoder(src)
+
+ def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
+ """Forward pass with the decoder of the transformer + classification head."""
+ return self.head(
+ self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+ )
+
+ def extract_image_features(self, src: Tensor) -> Tensor:
+ """Extracts image features with a backbone neural network.
+
+ It seem like the winning idea was to swap channels and width dimension and collapse
+ the height dimension. The transformer is learning like a baby with this implementation!!! :D
+ Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D
+
+ Args:
+ src (Tensor): Input tensor.
+
+ Returns:
+ Tensor: A input src to the transformer.
+
+ """
+ # If batch dimension is missing, it needs to be added.
+ if len(src.shape) < 4:
+ src = src[(None,) * (4 - len(src.shape))]
+
+ patches = rearrange(
+ src,
+ "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
+ p1=self.patch_dim[0],
+ p2=self.patch_dim[1],
+ )
+
+ # From patches to encoded sequence.
+ x = self.patch_to_embedding(patches)
+ b, n, _ = x.shape
+ cls_tokens = repeat(self.cls_token, "() n d -> b n d", b=b)
+ x = torch.cat((cls_tokens, x), dim=1)
+ x += self.pos_embedding[:, : (n + 1)]
+ x = self.dropout(x)
+
+ return x
+
+ def target_embedding(self, trg: Tensor) -> Tuple[Tensor, Tensor]:
+ """Encodes target tensor with embedding and postion.
+
+ Args:
+ trg (Tensor): Target tensor.
+
+ Returns:
+ Tuple[Tensor, Tensor]: Encoded target tensor and target mask.
+
+ """
+ _, n = trg.shape
+ trg = self.character_embedding(trg.long())
+ trg += self.pos_embedding[:, :n]
+ return trg
+
+ def decode_image_features(self, h: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+ """Takes images features from the backbone and decodes them with the transformer."""
+ trg_mask = self._create_trg_mask(trg)
+ trg = self.target_embedding(trg)
+ out = self.transformer(h, trg, trg_mask=trg_mask)
+
+ logits = self.head(out)
+ return logits
+
+ def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+ """Forward pass with CNN transfomer."""
+ h = self.extract_image_features(x)
+ logits = self.decode_image_features(h, trg)
+ return logits
diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py
new file mode 100644
index 0000000..c673d96
--- /dev/null
+++ b/text_recognizer/networks/vq_transformer.py
@@ -0,0 +1,150 @@
+"""A VQ-Transformer for image to text recognition."""
+from typing import Dict, Optional, Tuple
+
+from einops import rearrange, repeat
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer import PositionalEncoding, Transformer
+from text_recognizer.networks.util import activation_function
+from text_recognizer.networks.util import configure_backbone
+from text_recognizer.networks.vqvae.encoder import _ResidualBlock
+
+
+class VQTransformer(nn.Module):
+ """VQ+Transfomer for image to character sequence prediction."""
+
+ def __init__(
+ self,
+ num_encoder_layers: int,
+ num_decoder_layers: int,
+ hidden_dim: int,
+ vocab_size: int,
+ num_heads: int,
+ adaptive_pool_dim: Tuple,
+ expansion_dim: int,
+ dropout_rate: float,
+ trg_pad_index: int,
+ max_len: int,
+ backbone: str,
+ backbone_args: Optional[Dict] = None,
+ activation: str = "gelu",
+ ) -> None:
+ super().__init__()
+
+ # Configure vector quantized backbone.
+ self.backbone = configure_backbone(backbone, backbone_args)
+ self.conv = nn.Sequential(
+ nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2),
+ nn.ReLU(inplace=True),
+ )
+
+ # Configure embeddings for Transformer network.
+ self.trg_pad_index = trg_pad_index
+ self.vocab_size = vocab_size
+ self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim)
+ self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
+ self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+ nn.init.normal_(self.character_embedding.weight, std=0.02)
+
+ self.adaptive_pool = (
+ nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None
+ )
+
+ self.transformer = Transformer(
+ num_encoder_layers,
+ num_decoder_layers,
+ hidden_dim,
+ num_heads,
+ expansion_dim,
+ dropout_rate,
+ activation,
+ )
+
+ self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
+
+ def _create_trg_mask(self, trg: Tensor) -> Tensor:
+ # Move this outside the transformer.
+ trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
+ trg_len = trg.shape[1]
+ trg_sub_mask = torch.tril(
+ torch.ones((trg_len, trg_len), device=trg.device)
+ ).bool()
+ trg_mask = trg_pad_mask & trg_sub_mask
+ return trg_mask
+
+ def encoder(self, src: Tensor) -> Tensor:
+ """Forward pass with the encoder of the transformer."""
+ return self.transformer.encoder(src)
+
+ def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
+ """Forward pass with the decoder of the transformer + classification head."""
+ return self.head(
+ self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+ )
+
+ def extract_image_features(self, src: Tensor) -> Tuple[Tensor, Tensor]:
+ """Extracts image features with a backbone neural network.
+
+ It seem like the winning idea was to swap channels and width dimension and collapse
+ the height dimension. The transformer is learning like a baby with this implementation!!! :D
+ Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D
+
+ Args:
+ src (Tensor): Input tensor.
+
+ Returns:
+ Tensor: The input src to the transformer and the vq loss.
+
+ """
+ # If batch dimension is missing, it needs to be added.
+ if len(src.shape) < 4:
+ src = src[(None,) * (4 - len(src.shape))]
+ src, vq_loss = self.backbone.encode(src)
+ # src = self.backbone.decoder.res_block(src)
+ src = self.conv(src)
+
+ if self.adaptive_pool is not None:
+ src = rearrange(src, "b c h w -> b w c h")
+ src = self.adaptive_pool(src)
+ src = src.squeeze(3)
+ else:
+ src = rearrange(src, "b c h w -> b (w h) c")
+
+ b, t, _ = src.shape
+
+ src += self.src_position_embedding[:, :t]
+
+ return src, vq_loss
+
+ def target_embedding(self, trg: Tensor) -> Tensor:
+ """Encodes target tensor with embedding and postion.
+
+ Args:
+ trg (Tensor): Target tensor.
+
+ Returns:
+ Tensor: Encoded target tensor.
+
+ """
+ trg = self.character_embedding(trg.long())
+ trg = self.trg_position_encoding(trg)
+ return trg
+
+ def decode_image_features(
+ self, image_features: Tensor, trg: Optional[Tensor] = None
+ ) -> Tensor:
+ """Takes images features from the backbone and decodes them with the transformer."""
+ trg_mask = self._create_trg_mask(trg)
+ trg = self.target_embedding(trg)
+ out = self.transformer(image_features, trg, trg_mask=trg_mask)
+
+ logits = self.head(out)
+ return logits
+
+ def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+ """Forward pass with CNN transfomer."""
+ image_features, vq_loss = self.extract_image_features(x)
+ logits = self.decode_image_features(image_features, trg)
+ return logits, vq_loss
diff --git a/text_recognizer/networks/vqvae/__init__.py b/text_recognizer/networks/vqvae/__init__.py
new file mode 100644
index 0000000..763953c
--- /dev/null
+++ b/text_recognizer/networks/vqvae/__init__.py
@@ -0,0 +1,5 @@
+"""VQ-VAE module."""
+from .decoder import Decoder
+from .encoder import Encoder
+from .vector_quantizer import VectorQuantizer
+from .vqvae import VQVAE
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
new file mode 100644
index 0000000..8847aba
--- /dev/null
+++ b/text_recognizer/networks/vqvae/decoder.py
@@ -0,0 +1,133 @@
+"""CNN decoder for the VQ-VAE."""
+
+from typing import List, Optional, Tuple, Type
+
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import activation_function
+from text_recognizer.networks.vqvae.encoder import _ResidualBlock
+
+
+class Decoder(nn.Module):
+ """A CNN encoder network."""
+
+ def __init__(
+ self,
+ channels: List[int],
+ kernel_sizes: List[int],
+ strides: List[int],
+ num_residual_layers: int,
+ embedding_dim: int,
+ upsampling: Optional[List[List[int]]] = None,
+ activation: str = "leaky_relu",
+ dropout_rate: float = 0.0,
+ ) -> None:
+ super().__init__()
+
+ if dropout_rate:
+ if activation == "selu":
+ dropout = nn.AlphaDropout(p=dropout_rate)
+ else:
+ dropout = nn.Dropout(p=dropout_rate)
+ else:
+ dropout = None
+
+ self.upsampling = upsampling
+
+ self.res_block = nn.ModuleList([])
+ self.upsampling_block = nn.ModuleList([])
+
+ self.embedding_dim = embedding_dim
+ activation = activation_function(activation)
+
+ # Configure encoder.
+ self.decoder = self._build_decoder(
+ channels, kernel_sizes, strides, num_residual_layers, activation, dropout,
+ )
+
+ def _build_decompression_block(
+ self,
+ in_channels: int,
+ channels: int,
+ kernel_sizes: List[int],
+ strides: List[int],
+ activation: Type[nn.Module],
+ dropout: Optional[Type[nn.Module]],
+ ) -> nn.ModuleList:
+ modules = nn.ModuleList([])
+ configuration = zip(channels, kernel_sizes, strides)
+ for i, (out_channels, kernel_size, stride) in enumerate(configuration):
+ modules.append(
+ nn.Sequential(
+ nn.ConvTranspose2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=1,
+ ),
+ activation,
+ )
+ )
+
+ if i < len(self.upsampling):
+ modules.append(nn.Upsample(size=self.upsampling[i]),)
+
+ if dropout is not None:
+ modules.append(dropout)
+
+ in_channels = out_channels
+
+ modules.extend(
+ nn.Sequential(
+ nn.ConvTranspose2d(
+ in_channels, 1, kernel_size=kernel_size, stride=stride, padding=1
+ ),
+ nn.Tanh(),
+ )
+ )
+
+ return modules
+
+ def _build_decoder(
+ self,
+ channels: int,
+ kernel_sizes: List[int],
+ strides: List[int],
+ num_residual_layers: int,
+ activation: Type[nn.Module],
+ dropout: Optional[Type[nn.Module]],
+ ) -> nn.Sequential:
+
+ self.res_block.append(
+ nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,)
+ )
+
+ # Bottleneck module.
+ self.res_block.extend(
+ nn.ModuleList(
+ [
+ _ResidualBlock(channels[0], channels[0], dropout)
+ for i in range(num_residual_layers)
+ ]
+ )
+ )
+
+ # Decompression module
+ self.upsampling_block.extend(
+ self._build_decompression_block(
+ channels[0], channels[1:], kernel_sizes, strides, activation, dropout
+ )
+ )
+
+ self.res_block = nn.Sequential(*self.res_block)
+ self.upsampling_block = nn.Sequential(*self.upsampling_block)
+
+ return nn.Sequential(self.res_block, self.upsampling_block)
+
+ def forward(self, z_q: Tensor) -> Tensor:
+ """Reconstruct input from given codes."""
+ x_reconstruction = self.decoder(z_q)
+ return x_reconstruction
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
new file mode 100644
index 0000000..d3adac5
--- /dev/null
+++ b/text_recognizer/networks/vqvae/encoder.py
@@ -0,0 +1,147 @@
+"""CNN encoder for the VQ-VAE."""
+from typing import List, Optional, Tuple, Type
+
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import activation_function
+from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer
+
+
+class _ResidualBlock(nn.Module):
+ def __init__(
+ self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]],
+ ) -> None:
+ super().__init__()
+ self.block = [
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False),
+ ]
+
+ if dropout is not None:
+ self.block.append(dropout)
+
+ self.block = nn.Sequential(*self.block)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Apply the residual forward pass."""
+ return x + self.block(x)
+
+
+class Encoder(nn.Module):
+ """A CNN encoder network."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ channels: List[int],
+ kernel_sizes: List[int],
+ strides: List[int],
+ num_residual_layers: int,
+ embedding_dim: int,
+ num_embeddings: int,
+ beta: float = 0.25,
+ activation: str = "leaky_relu",
+ dropout_rate: float = 0.0,
+ ) -> None:
+ super().__init__()
+
+ if dropout_rate:
+ if activation == "selu":
+ dropout = nn.AlphaDropout(p=dropout_rate)
+ else:
+ dropout = nn.Dropout(p=dropout_rate)
+ else:
+ dropout = None
+
+ self.embedding_dim = embedding_dim
+ self.num_embeddings = num_embeddings
+ self.beta = beta
+ activation = activation_function(activation)
+
+ # Configure encoder.
+ self.encoder = self._build_encoder(
+ in_channels,
+ channels,
+ kernel_sizes,
+ strides,
+ num_residual_layers,
+ activation,
+ dropout,
+ )
+
+ # Configure Vector Quantizer.
+ self.vector_quantizer = VectorQuantizer(
+ self.num_embeddings, self.embedding_dim, self.beta
+ )
+
+ def _build_compression_block(
+ self,
+ in_channels: int,
+ channels: int,
+ kernel_sizes: List[int],
+ strides: List[int],
+ activation: Type[nn.Module],
+ dropout: Optional[Type[nn.Module]],
+ ) -> nn.ModuleList:
+ modules = nn.ModuleList([])
+ configuration = zip(channels, kernel_sizes, strides)
+ for out_channels, kernel_size, stride in configuration:
+ modules.append(
+ nn.Sequential(
+ nn.Conv2d(
+ in_channels, out_channels, kernel_size, stride=stride, padding=1
+ ),
+ activation,
+ )
+ )
+
+ if dropout is not None:
+ modules.append(dropout)
+
+ in_channels = out_channels
+
+ return modules
+
+ def _build_encoder(
+ self,
+ in_channels: int,
+ channels: int,
+ kernel_sizes: List[int],
+ strides: List[int],
+ num_residual_layers: int,
+ activation: Type[nn.Module],
+ dropout: Optional[Type[nn.Module]],
+ ) -> nn.Sequential:
+ encoder = nn.ModuleList([])
+
+ # compression module
+ encoder.extend(
+ self._build_compression_block(
+ in_channels, channels, kernel_sizes, strides, activation, dropout
+ )
+ )
+
+ # Bottleneck module.
+ encoder.extend(
+ nn.ModuleList(
+ [
+ _ResidualBlock(channels[-1], channels[-1], dropout)
+ for i in range(num_residual_layers)
+ ]
+ )
+ )
+
+ encoder.append(
+ nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,)
+ )
+
+ return nn.Sequential(*encoder)
+
+ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ """Encodes input into a discrete representation."""
+ z_e = self.encoder(x)
+ z_q, vq_loss = self.vector_quantizer(z_e)
+ return z_q, vq_loss
diff --git a/text_recognizer/networks/vqvae/vector_quantizer.py b/text_recognizer/networks/vqvae/vector_quantizer.py
new file mode 100644
index 0000000..f92c7ee
--- /dev/null
+++ b/text_recognizer/networks/vqvae/vector_quantizer.py
@@ -0,0 +1,119 @@
+"""Implementation of a Vector Quantized Variational AutoEncoder.
+
+Reference:
+https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py
+
+"""
+
+from einops import rearrange
+import torch
+from torch import nn
+from torch import Tensor
+from torch.nn import functional as F
+
+
+class VectorQuantizer(nn.Module):
+ """The codebook that contains quantized vectors."""
+
+ def __init__(
+ self, num_embeddings: int, embedding_dim: int, beta: float = 0.25
+ ) -> None:
+ super().__init__()
+ self.K = num_embeddings
+ self.D = embedding_dim
+ self.beta = beta
+
+ self.embedding = nn.Embedding(self.K, self.D)
+
+ # Initialize the codebook.
+ nn.init.uniform_(self.embedding.weight, -1 / self.K, 1 / self.K)
+
+ def discretization_bottleneck(self, latent: Tensor) -> Tensor:
+ """Computes the code nearest to the latent representation.
+
+ First we compute the posterior categorical distribution, and then map
+ the latent representation to the nearest element of the embedding.
+
+ Args:
+ latent (Tensor): The latent representation.
+
+ Shape:
+ - latent :math:`(B x H x W, D)`
+
+ Returns:
+ Tensor: The quantized embedding vector.
+
+ """
+ # Store latent shape.
+ b, h, w, d = latent.shape
+
+ # Flatten the latent representation to 2D.
+ latent = rearrange(latent, "b h w d -> (b h w) d")
+
+ # Compute the L2 distance between the latents and the embeddings.
+ l2_distance = (
+ torch.sum(latent ** 2, dim=1, keepdim=True)
+ + torch.sum(self.embedding.weight ** 2, dim=1)
+ - 2 * latent @ self.embedding.weight.t()
+ ) # [BHW x K]
+
+ # Find the embedding k nearest to each latent.
+ encoding_indices = torch.argmin(l2_distance, dim=1).unsqueeze(1) # [BHW, 1]
+
+ # Convert to one-hot encodings, aka discrete bottleneck.
+ one_hot_encoding = torch.zeros(
+ encoding_indices.shape[0], self.K, device=latent.device
+ )
+ one_hot_encoding.scatter_(1, encoding_indices, 1) # [BHW x K]
+
+ # Embedding quantization.
+ quantized_latent = one_hot_encoding @ self.embedding.weight # [BHW, D]
+ quantized_latent = rearrange(
+ quantized_latent, "(b h w) d -> b h w d", b=b, h=h, w=w
+ )
+
+ return quantized_latent
+
+ def vq_loss(self, latent: Tensor, quantized_latent: Tensor) -> Tensor:
+ """Vector Quantization loss.
+
+ The vector quantization algorithm allows us to create a codebook. The VQ
+ algorithm works by moving the embedding vectors towards the encoder outputs.
+
+ The embedding loss moves the embedding vector towards the encoder outputs. The
+ .detach() works as the stop gradient (sg) described in the paper.
+
+ Because the volume of the embedding space is dimensionless, it can arbitarily
+ grow if the embeddings are not trained as fast as the encoder parameters. To
+ mitigate this, a commitment loss is added in the second term which makes sure
+ that the encoder commits to an embedding and that its output does not grow.
+
+ Args:
+ latent (Tensor): The encoder output.
+ quantized_latent (Tensor): The quantized latent.
+
+ Returns:
+ Tensor: The combinded VQ loss.
+
+ """
+ embedding_loss = F.mse_loss(quantized_latent, latent.detach())
+ commitment_loss = F.mse_loss(quantized_latent.detach(), latent)
+ return embedding_loss + self.beta * commitment_loss
+
+ def forward(self, latent: Tensor) -> Tensor:
+ """Forward pass that returns the quantized vector and the vq loss."""
+ # Rearrange latent representation s.t. the hidden dim is at the end.
+ latent = rearrange(latent, "b d h w -> b h w d")
+
+ # Maps latent to the nearest code in the codebook.
+ quantized_latent = self.discretization_bottleneck(latent)
+
+ loss = self.vq_loss(latent, quantized_latent)
+
+ # Add residue to the quantized latent.
+ quantized_latent = latent + (quantized_latent - latent).detach()
+
+ # Rearrange the quantized shape back to the original shape.
+ quantized_latent = rearrange(quantized_latent, "b h w d -> b d h w")
+
+ return quantized_latent, loss
diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py
new file mode 100644
index 0000000..50448b4
--- /dev/null
+++ b/text_recognizer/networks/vqvae/vqvae.py
@@ -0,0 +1,74 @@
+"""The VQ-VAE."""
+
+from typing import List, Optional, Tuple, Type
+
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.vqvae import Decoder, Encoder
+
+
+class VQVAE(nn.Module):
+ """Vector Quantized Variational AutoEncoder."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ channels: List[int],
+ kernel_sizes: List[int],
+ strides: List[int],
+ num_residual_layers: int,
+ embedding_dim: int,
+ num_embeddings: int,
+ upsampling: Optional[List[List[int]]] = None,
+ beta: float = 0.25,
+ activation: str = "leaky_relu",
+ dropout_rate: float = 0.0,
+ ) -> None:
+ super().__init__()
+
+ # configure encoder.
+ self.encoder = Encoder(
+ in_channels,
+ channels,
+ kernel_sizes,
+ strides,
+ num_residual_layers,
+ embedding_dim,
+ num_embeddings,
+ beta,
+ activation,
+ dropout_rate,
+ )
+
+ # Configure decoder.
+ channels.reverse()
+ kernel_sizes.reverse()
+ strides.reverse()
+ self.decoder = Decoder(
+ channels,
+ kernel_sizes,
+ strides,
+ num_residual_layers,
+ embedding_dim,
+ upsampling,
+ activation,
+ dropout_rate,
+ )
+
+ def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ """Encodes input to a latent code."""
+ return self.encoder(x)
+
+ def decode(self, z_q: Tensor) -> Tensor:
+ """Reconstructs input from latent codes."""
+ return self.decoder(z_q)
+
+ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ """Compresses and decompresses input."""
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
+ z_q, vq_loss = self.encode(x)
+ x_reconstruction = self.decode(z_q)
+ return x_reconstruction, vq_loss
diff --git a/text_recognizer/networks/wide_resnet.py b/text_recognizer/networks/wide_resnet.py
new file mode 100644
index 0000000..b767778
--- /dev/null
+++ b/text_recognizer/networks/wide_resnet.py
@@ -0,0 +1,221 @@
+"""Wide Residual CNN."""
+from functools import partial
+from typing import Callable, Dict, List, Optional, Type, Union
+
+from einops.layers.torch import Reduce
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import activation_function
+
+
+def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
+ """Helper function for a 3x3 2d convolution."""
+ return nn.Conv2d(
+ in_channels=in_planes,
+ out_channels=out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias=False,
+ )
+
+
+def conv_init(module: Type[nn.Module]) -> None:
+ """Initializes the weights for convolution and batchnorms."""
+ classname = module.__class__.__name__
+ if classname.find("Conv") != -1:
+ nn.init.xavier_uniform_(module.weight, gain=np.sqrt(2))
+ nn.init.constant_(module.bias, 0)
+ elif classname.find("BatchNorm") != -1:
+ nn.init.constant_(module.weight, 1)
+ nn.init.constant_(module.bias, 0)
+
+
+class WideBlock(nn.Module):
+ """Block used in WideResNet."""
+
+ def __init__(
+ self,
+ in_planes: int,
+ out_planes: int,
+ dropout_rate: float,
+ stride: int = 1,
+ activation: str = "relu",
+ ) -> None:
+ super().__init__()
+ self.in_planes = in_planes
+ self.out_planes = out_planes
+ self.dropout_rate = dropout_rate
+ self.stride = stride
+ self.activation = activation_function(activation)
+
+ # Build blocks.
+ self.blocks = nn.Sequential(
+ nn.BatchNorm2d(self.in_planes),
+ self.activation,
+ conv3x3(in_planes=self.in_planes, out_planes=self.out_planes),
+ nn.Dropout(p=self.dropout_rate),
+ nn.BatchNorm2d(self.out_planes),
+ self.activation,
+ conv3x3(
+ in_planes=self.out_planes,
+ out_planes=self.out_planes,
+ stride=self.stride,
+ ),
+ )
+
+ self.shortcut = (
+ nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.in_planes,
+ out_channels=self.out_planes,
+ kernel_size=1,
+ stride=self.stride,
+ bias=False,
+ ),
+ )
+ if self._apply_shortcut
+ else None
+ )
+
+ @property
+ def _apply_shortcut(self) -> bool:
+ """If shortcut should be applied or not."""
+ return self.stride != 1 or self.in_planes != self.out_planes
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ residual = x
+ if self._apply_shortcut:
+ residual = self.shortcut(x)
+ x = self.blocks(x)
+ x += residual
+ return x
+
+
+class WideResidualNetwork(nn.Module):
+ """WideResNet for character predictions.
+
+ Can be used for classification or encoding of images to a latent vector.
+
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 1,
+ in_planes: int = 16,
+ num_classes: int = 80,
+ depth: int = 16,
+ width_factor: int = 10,
+ dropout_rate: float = 0.0,
+ num_layers: int = 3,
+ block: Type[nn.Module] = WideBlock,
+ num_stages: Optional[List[int]] = None,
+ activation: str = "relu",
+ use_decoder: bool = True,
+ ) -> None:
+ """The initialization of the WideResNet.
+
+ Args:
+ in_channels (int): Number of input channels. Defaults to 1.
+ in_planes (int): Number of channels to use in the first output kernel. Defaults to 16.
+ num_classes (int): Number of classes. Defaults to 80.
+ depth (int): Set the number of blocks to use. Defaults to 16.
+ width_factor (int): Factor for scaling the number of channels in the network. Defaults to 10.
+ dropout_rate (float): The dropout rate. Defaults to 0.0.
+ num_layers (int): Number of layers of blocks. Defaults to 3.
+ block (Type[nn.Module]): The default block is WideBlock. Defaults to WideBlock.
+ num_stages (List[int]): If given, will use these channel values. Defaults to None.
+ activation (str): Name of the activation to use. Defaults to "relu".
+ use_decoder (bool): If True, the network output character predictions, if False, the network outputs a
+ latent vector. Defaults to True.
+
+ Raises:
+ RuntimeError: If the depth is not of the size `6n+4`.
+
+ """
+
+ super().__init__()
+ if (depth - 4) % 6 != 0:
+ raise RuntimeError("Wide-resnet depth should be 6n+4")
+ self.in_channels = in_channels
+ self.in_planes = in_planes
+ self.num_classes = num_classes
+ self.num_blocks = (depth - 4) // 6
+ self.width_factor = width_factor
+ self.num_layers = num_layers
+ self.block = block
+ self.dropout_rate = dropout_rate
+ self.activation = activation_function(activation)
+
+ if num_stages is None:
+ self.num_stages = [self.in_planes] + [
+ self.in_planes * 2 ** n * self.width_factor
+ for n in range(self.num_layers)
+ ]
+ else:
+ self.num_stages = [self.in_planes] + num_stages
+
+ self.num_stages = list(zip(self.num_stages, self.num_stages[1:]))
+ self.strides = [1] + [2] * (self.num_layers - 1)
+
+ self.encoder = nn.Sequential(
+ conv3x3(in_planes=self.in_channels, out_planes=self.in_planes),
+ *[
+ self._configure_wide_layer(
+ in_planes=in_planes,
+ out_planes=out_planes,
+ stride=stride,
+ activation=activation,
+ )
+ for (in_planes, out_planes), stride in zip(
+ self.num_stages, self.strides
+ )
+ ],
+ )
+
+ self.decoder = (
+ nn.Sequential(
+ nn.BatchNorm2d(self.num_stages[-1][-1], momentum=0.8),
+ self.activation,
+ Reduce("b c h w -> b c", "mean"),
+ nn.Linear(
+ in_features=self.num_stages[-1][-1], out_features=self.num_classes
+ ),
+ )
+ if use_decoder
+ else None
+ )
+
+ # self.apply(conv_init)
+
+ def _configure_wide_layer(
+ self, in_planes: int, out_planes: int, stride: int, activation: str
+ ) -> List:
+ strides = [stride] + [1] * (self.num_blocks - 1)
+ planes = [out_planes] * len(strides)
+ planes = [(in_planes, out_planes)] + list(zip(planes, planes[1:]))
+ return nn.Sequential(
+ *[
+ self.block(
+ in_planes=in_planes,
+ out_planes=out_planes,
+ dropout_rate=self.dropout_rate,
+ stride=stride,
+ activation=activation,
+ )
+ for (in_planes, out_planes), stride in zip(planes, strides)
+ ]
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Feedforward pass."""
+ if len(x.shape) < 4:
+ x = x[(None,) * int(4 - len(x.shape))]
+ x = self.encoder(x)
+ if self.decoder is not None:
+ x = self.decoder(x)
+ return x
diff --git a/text_recognizer/paragraph_text_recognizer.py b/text_recognizer/paragraph_text_recognizer.py
new file mode 100644
index 0000000..aa39662
--- /dev/null
+++ b/text_recognizer/paragraph_text_recognizer.py
@@ -0,0 +1,153 @@
+"""Full model.
+
+Takes an image and returns the text in the image, by first segmenting the image with a LineDetector, then extracting the
+each crop of the image corresponding to line regions, and feeding them to a LinePredictor model that outputs the text
+in each region.
+"""
+from typing import Dict, List, Tuple, Union
+
+import cv2
+import numpy as np
+import torch
+
+from text_recognizer.models import SegmentationModel, TransformerModel
+from text_recognizer.util import read_image
+
+
+class ParagraphTextRecognizor:
+ """Given an image of a single handwritten character, recognizes it."""
+
+ def __init__(self, line_predictor_args: Dict, line_detector_args: Dict) -> None:
+ self._line_predictor = TransformerModel(**line_predictor_args)
+ self._line_detector = SegmentationModel(**line_detector_args)
+ self._line_detector.eval()
+ self._line_predictor.eval()
+
+ def predict(self, image_or_filename: Union[str, np.ndarray]) -> Tuple:
+ """Takes an image and returns all text within it."""
+ image = (
+ read_image(image_or_filename)
+ if isinstance(image_or_filename, str)
+ else image_or_filename
+ )
+
+ line_region_crops = self._get_line_region_crops(image)
+ processed_line_region_crops = [
+ self._process_image_for_line_predictor(image=crop)
+ for crop in line_region_crops
+ ]
+ line_region_strings = [
+ self.line_predictor_model.predict_on_image(crop)[0]
+ for crop in processed_line_region_crops
+ ]
+
+ return " ".join(line_region_strings), line_region_crops
+
+ def _get_line_region_crops(
+ self, image: np.ndarray, min_crop_len_factor: float = 0.02
+ ) -> List[np.ndarray]:
+ """Returns all the crops of text lines in a square image."""
+ processed_image, scale_down_factor = self._process_image_for_line_detector(
+ image
+ )
+ line_segmentation = self._line_detector.predict_on_image(processed_image)
+ bounding_boxes = _find_line_bounding_boxes(line_segmentation)
+
+ bounding_boxes = (bounding_boxes * scale_down_factor).astype(int)
+
+ min_crop_len = int(min_crop_len_factor * min(image.shape[0], image.shape[1]))
+ line_region_crops = [
+ image[y : y + h, x : x + w]
+ for x, y, w, h in bounding_boxes
+ if w >= min_crop_len and h >= min_crop_len
+ ]
+ return line_region_crops
+
+ def _process_image_for_line_detector(
+ self, image: np.ndarray
+ ) -> Tuple[np.ndarray, float]:
+ """Convert uint8 image to float image with black background with shape self._line_detector.image_shape."""
+ resized_image, scale_down_factor = _resize_image_for_line_detector(
+ image=image, max_shape=self._line_detector.image_shape
+ )
+ resized_image = (1.0 - resized_image / 255).astype("float32")
+ return resized_image, scale_down_factor
+
+ def _process_image_for_line_predictor(self, image: np.ndarray) -> np.ndarray:
+ """Preprocessing of image before feeding it to the LinePrediction model.
+
+ Convert uint8 image to float image with black background with shape
+ self._line_predictor.image_shape while maintaining the image aspect ratio.
+
+ Args:
+ image (np.ndarray): Crop of text line.
+
+ Returns:
+ np.ndarray: Processed crop for feeding line predictor.
+ """
+ expected_shape = self._line_detector.image_shape
+ scale_factor = (np.array(expected_shape) / np.array(image.shape)).min()
+ scaled_image = cv2.resize(
+ image,
+ dsize=None,
+ fx=scale_factor,
+ fy=scale_factor,
+ interpolation=cv2.INTER_AREA,
+ )
+
+ pad_with = (
+ (0, expected_shape[0] - scaled_image.shape[0]),
+ (0, expected_shape[1] - scaled_image.shape[1]),
+ )
+
+ padded_image = np.pad(
+ scaled_image, pad_with=pad_with, mode="constant", constant_values=255
+ )
+ return 1 - padded_image / 255
+
+
+def _find_line_bounding_boxes(line_segmentation: np.ndarray) -> np.ndarray:
+ """Given a line segmentation, find bounding boxes for connected-component regions corresponding to non-0 labels."""
+
+ def _find_line_bounding_boxes_in_channel(
+ line_segmentation_channel: np.ndarray,
+ ) -> np.ndarray:
+ line_segmentation_image = cv2.dilate(
+ line_segmentation_channel, kernel=np.ones((3, 3)), iterations=1
+ )
+ line_activation_image = (line_segmentation_image * 255).astype("uint8")
+ line_activation_image = cv2.threshold(
+ line_activation_image, 0.5, 1, cv2.THRESH_BINARY | cv2.THRESH_OTSU
+ )[1]
+
+ bounding_cnts, _ = cv2.findContours(
+ line_segmentation_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
+ )
+ return np.array([cv2.boundingRect(cnt) for cnt in bounding_cnts])
+
+ bounding_boxes = np.concatenate(
+ [
+ _find_line_bounding_boxes_in_channel(line_segmentation[:, :, i])
+ for i in [1, 2]
+ ],
+ axis=0,
+ )
+
+ return bounding_boxes[np.argsort(bounding_boxes[:, 1])]
+
+
+def _resize_image_for_line_detector(
+ image: np.ndarray, max_shape: Tuple[int, int]
+) -> Tuple[np.ndarray, float]:
+ """Resize the image to less than the max_shape while maintaining the aspect ratio."""
+ scale_down_factor = max(np.ndarray(image.shape) / np.ndarray(max_shape))
+ if scale_down_factor == 1:
+ return image.copy(), scale_down_factor
+ resize_image = cv2.resize(
+ image,
+ dsize=None,
+ fx=1 / scale_down_factor,
+ fy=1 / scale_down_factor,
+ interpolation=cv2.INTER_AREA,
+ )
+ return resize_image, scale_down_factor
diff --git a/text_recognizer/tests/__init__.py b/text_recognizer/tests/__init__.py
new file mode 100644
index 0000000..18ff212
--- /dev/null
+++ b/text_recognizer/tests/__init__.py
@@ -0,0 +1 @@
+"""Test modules for the text text recognizer."""
diff --git a/text_recognizer/tests/support/__init__.py b/text_recognizer/tests/support/__init__.py
new file mode 100644
index 0000000..a265ede
--- /dev/null
+++ b/text_recognizer/tests/support/__init__.py
@@ -0,0 +1,2 @@
+"""Support file modules."""
+from .create_emnist_support_files import create_emnist_support_files
diff --git a/text_recognizer/tests/support/create_emnist_lines_support_files.py b/text_recognizer/tests/support/create_emnist_lines_support_files.py
new file mode 100644
index 0000000..9abe143
--- /dev/null
+++ b/text_recognizer/tests/support/create_emnist_lines_support_files.py
@@ -0,0 +1,51 @@
+"""Module for creating EMNIST Lines test support files."""
+# flake8: noqa: S106
+
+from pathlib import Path
+import shutil
+
+import numpy as np
+
+from text_recognizer.datasets import EmnistLinesDataset
+import text_recognizer.util as util
+
+
+SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "emnist_lines"
+
+
+def create_emnist_lines_support_files() -> None:
+ """Create EMNIST Lines test images."""
+ shutil.rmtree(SUPPORT_DIRNAME, ignore_errors=True)
+ SUPPORT_DIRNAME.mkdir()
+
+ # TODO: maybe have to add args to dataset.
+ dataset = EmnistLinesDataset(
+ init_token="<sos>",
+ pad_token="_",
+ eos_token="<eos>",
+ transform=[{"type": "ToTensor", "args": {}}],
+ target_transform=[
+ {
+ "type": "AddTokens",
+ "args": {"init_token": "<sos>", "pad_token": "_", "eos_token": "<eos>"},
+ }
+ ],
+ ) # nosec: S106
+ dataset.load_or_generate_data()
+
+ for index in [5, 7, 9]:
+ image, target = dataset[index]
+ if len(image.shape) == 3:
+ image = image.squeeze(0)
+ print(image.sum(), image.dtype)
+
+ label = "".join(dataset.mapper(label) for label in target[1:]).strip(
+ dataset.mapper.pad_token
+ )
+ print(label)
+ image = image.numpy()
+ util.write_image(image, str(SUPPORT_DIRNAME / f"{label}.png"))
+
+
+if __name__ == "__main__":
+ create_emnist_lines_support_files()
diff --git a/text_recognizer/tests/support/create_emnist_support_files.py b/text_recognizer/tests/support/create_emnist_support_files.py
new file mode 100644
index 0000000..f9ff030
--- /dev/null
+++ b/text_recognizer/tests/support/create_emnist_support_files.py
@@ -0,0 +1,30 @@
+"""Module for creating EMNIST test support files."""
+from pathlib import Path
+import shutil
+
+from text_recognizer.datasets import EmnistDataset
+from text_recognizer.util import write_image
+
+SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "emnist"
+
+
+def create_emnist_support_files() -> None:
+ """Create support images for test of CharacterPredictor class."""
+ shutil.rmtree(SUPPORT_DIRNAME, ignore_errors=True)
+ SUPPORT_DIRNAME.mkdir()
+
+ dataset = EmnistDataset(train=False)
+ dataset.load_or_generate_data()
+
+ for index in [5, 7, 9]:
+ image, label = dataset[index]
+ if len(image.shape) == 3:
+ image = image.squeeze(0)
+ image = image.numpy()
+ label = dataset.mapper(int(label))
+ print(index, label)
+ write_image(image, str(SUPPORT_DIRNAME / f"{label}.png"))
+
+
+if __name__ == "__main__":
+ create_emnist_support_files()
diff --git a/text_recognizer/tests/support/create_iam_lines_support_files.py b/text_recognizer/tests/support/create_iam_lines_support_files.py
new file mode 100644
index 0000000..50f9e3d
--- /dev/null
+++ b/text_recognizer/tests/support/create_iam_lines_support_files.py
@@ -0,0 +1,50 @@
+"""Module for creating IAM Lines test support files."""
+# flake8: noqa
+from pathlib import Path
+import shutil
+
+import numpy as np
+
+from text_recognizer.datasets import IamLinesDataset
+import text_recognizer.util as util
+
+
+SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "iam_lines"
+
+
+def create_emnist_lines_support_files() -> None:
+ """Create IAM Lines test images."""
+ shutil.rmtree(SUPPORT_DIRNAME, ignore_errors=True)
+ SUPPORT_DIRNAME.mkdir()
+
+ # TODO: maybe have to add args to dataset.
+ dataset = IamLinesDataset(
+ init_token="<sos>",
+ pad_token="_",
+ eos_token="<eos>",
+ transform=[{"type": "ToTensor", "args": {}}],
+ target_transform=[
+ {
+ "type": "AddTokens",
+ "args": {"init_token": "<sos>", "pad_token": "_", "eos_token": "<eos>"},
+ }
+ ],
+ )
+ dataset.load_or_generate_data()
+
+ for index in [0, 1, 3]:
+ image, target = dataset[index]
+ if len(image.shape) == 3:
+ image = image.squeeze(0)
+ print(image.sum(), image.dtype)
+
+ label = "".join(dataset.mapper(label) for label in target[1:]).strip(
+ dataset.mapper.pad_token
+ )
+ print(label)
+ image = image.numpy()
+ util.write_image(image, str(SUPPORT_DIRNAME / f"{label}.png"))
+
+
+if __name__ == "__main__":
+ create_emnist_lines_support_files()
diff --git a/text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png b/text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png
new file mode 100644
index 0000000..b7d0618
--- /dev/null
+++ b/text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png
Binary files differ
diff --git a/text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png b/text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png
new file mode 100644
index 0000000..14a8cf3
--- /dev/null
+++ b/text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png
Binary files differ
diff --git a/text_recognizer/tests/support/emnist_lines/they<eos>.png b/text_recognizer/tests/support/emnist_lines/they<eos>.png
new file mode 100644
index 0000000..7f05951
--- /dev/null
+++ b/text_recognizer/tests/support/emnist_lines/they<eos>.png
Binary files differ
diff --git a/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png b/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png
new file mode 100644
index 0000000..6eeb642
--- /dev/null
+++ b/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png
Binary files differ
diff --git a/text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png b/text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png
new file mode 100644
index 0000000..4974cf8
--- /dev/null
+++ b/text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png
Binary files differ
diff --git a/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png b/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png
new file mode 100644
index 0000000..a731245
--- /dev/null
+++ b/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png
Binary files differ
diff --git a/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg b/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg
new file mode 100644
index 0000000..d9753b6
--- /dev/null
+++ b/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg
Binary files differ
diff --git a/text_recognizer/tests/test_character_predictor.py b/text_recognizer/tests/test_character_predictor.py
new file mode 100644
index 0000000..01bda78
--- /dev/null
+++ b/text_recognizer/tests/test_character_predictor.py
@@ -0,0 +1,31 @@
+"""Test for CharacterPredictor class."""
+import importlib
+import os
+from pathlib import Path
+import unittest
+
+from loguru import logger
+
+from text_recognizer.character_predictor import CharacterPredictor
+from text_recognizer.networks import MLP
+
+SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "support" / "emnist"
+
+os.environ["CUDA_VISIBLE_DEVICES"] = ""
+
+
+class TestCharacterPredictor(unittest.TestCase):
+ """Tests for the CharacterPredictor class."""
+
+ def test_filename(self) -> None:
+ """Test that CharacterPredictor correctly predicts on a single image, for serveral test images."""
+ network_fn_ = MLP
+ predictor = CharacterPredictor(network_fn=network_fn_)
+
+ for filename in SUPPORT_DIRNAME.glob("*.png"):
+ pred, conf = predictor.predict(str(filename))
+ logger.info(
+ f"Prediction: {pred} at confidence: {conf} for image with character {filename.stem}"
+ )
+ self.assertEqual(pred, filename.stem)
+ self.assertGreater(conf, 0.7)
diff --git a/text_recognizer/tests/test_line_predictor.py b/text_recognizer/tests/test_line_predictor.py
new file mode 100644
index 0000000..eede4d4
--- /dev/null
+++ b/text_recognizer/tests/test_line_predictor.py
@@ -0,0 +1,35 @@
+"""Tests for LinePredictor."""
+import os
+from pathlib import Path
+import unittest
+
+
+import editdistance
+import numpy as np
+
+from text_recognizer.datasets import IamLinesDataset
+from text_recognizer.line_predictor import LinePredictor
+import text_recognizer.util as util
+
+SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "support"
+
+os.environ["CUDA_VISIBLE_DEVICES"] = ""
+
+
+class TestEmnistLinePredictor(unittest.TestCase):
+ """Test LinePredictor class on the EmnistLines dataset."""
+
+ def test_filename(self) -> None:
+ """Test that LinePredictor correctly predicts on single images, for several test images."""
+ predictor = LinePredictor(
+ dataset="EmnistLineDataset", network_fn="CNNTransformer"
+ )
+
+ for filename in (SUPPORT_DIRNAME / "emnist_lines").glob("*.png"):
+ pred, conf = predictor.predict(str(filename))
+ true = str(filename.stem)
+ edit_distance = editdistance.eval(pred, true) / len(pred)
+ print(
+ f'Pred: "{pred}" | Confidence: {conf} | True: {true} | Edit distance: {edit_distance}'
+ )
+ self.assertLess(edit_distance, 0.2)
diff --git a/text_recognizer/tests/test_paragraph_text_recognizer.py b/text_recognizer/tests/test_paragraph_text_recognizer.py
new file mode 100644
index 0000000..3e280b9
--- /dev/null
+++ b/text_recognizer/tests/test_paragraph_text_recognizer.py
@@ -0,0 +1,37 @@
+"""Test for ParagraphTextRecognizer class."""
+import os
+from pathlib import Path
+import unittest
+
+from text_recognizer.paragraph_text_recognizer import ParagraphTextRecognizor
+import text_recognizer.util as util
+
+
+SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "support" / "iam_paragraph"
+
+# Prevent using GPU.
+os.environ["CUDA_VISIBLE_DEVICES"] = ""
+
+
+class TestParagraphTextRecognizor(unittest.TestCase):
+ """Test that it can take non-square images of max dimension larger than 256px."""
+
+ def test_filename(self) -> None:
+ """Test model on support image."""
+ line_predictor_args = {
+ "dataset": "EmnistLineDataset",
+ "network_fn": "CNNTransformer",
+ }
+ line_detector_args = {"dataset": "EmnistLineDataset", "network_fn": "UNet"}
+ model = ParagraphTextRecognizor(
+ line_predictor_args=line_predictor_args,
+ line_detector_args=line_detector_args,
+ )
+ num_text_lines_by_name = {"a01-000u-cropped": 7}
+ for filename in (SUPPORT_DIRNAME).glob("*.jpg"):
+ full_image = util.read_image(str(filename), grayscale=True)
+ predicted_text, line_region_crops = model.predict(full_image)
+ print(predicted_text)
+ self.assertTrue(
+ len(line_region_crops), num_text_lines_by_name[filename.stem]
+ )
diff --git a/text_recognizer/util.py b/text_recognizer/util.py
new file mode 100644
index 0000000..b431e22
--- /dev/null
+++ b/text_recognizer/util.py
@@ -0,0 +1,52 @@
+"""Utility functions for text_recognizer module."""
+import os
+from pathlib import Path
+from typing import Union
+from urllib.request import urlopen
+
+import cv2
+import numpy as np
+
+
+def read_image(image_uri: Union[Path, str], grayscale: bool = False) -> np.ndarray:
+ """Read image_uri."""
+
+ def read_image_from_filename(image_filename: str, imread_flag: int) -> np.ndarray:
+ return cv2.imread(str(image_filename), imread_flag)
+
+ def read_image_from_url(image_url: str, imread_flag: int) -> np.ndarray:
+ if image_url.lower().startswith("http"):
+ url_response = urlopen(str(image_url))
+ image_array = np.array(bytearray(url_response.read()), dtype=np.uint8)
+ return cv2.imdecode(image_array, imread_flag)
+ else:
+ raise ValueError(
+ "Url does not start with http, therefore not safe to open..."
+ ) from None
+
+ imread_flag = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
+ local_file = os.path.exists(image_uri)
+ image = None
+
+ if local_file:
+ image = read_image_from_filename(image_uri, imread_flag)
+ else:
+ image = read_image_from_url(image_uri, imread_flag)
+
+ if image is None:
+ raise ValueError(f"Could not load image at {image_uri}")
+
+ return image
+
+
+def rescale_image(image: np.ndarray) -> np.ndarray:
+ """Rescale image from [0, 1] to [0, 255]."""
+ if image.max() <= 1.0:
+ image = 255 * (image - image.min()) / (image.max() - image.min())
+ return image
+
+
+def write_image(image: np.ndarray, filename: Union[Path, str]) -> None:
+ """Write image to file."""
+ image = rescale_image(image)
+ cv2.imwrite(str(filename), image)
diff --git a/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt b/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt
new file mode 100644
index 0000000..344e0a3
--- /dev/null
+++ b/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:46d483950ef0876ba072d06cd94021e08d99c4fa14eeccf22aeae1cbb2066b4f
+size 5628749
diff --git a/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt b/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt
new file mode 100644
index 0000000..f2dfd84
--- /dev/null
+++ b/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a69e5efedea70c4c5cb8ccdcc8cd480400f6c73e3313423f4dbbfe615644f0a
+size 4500617
diff --git a/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt b/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt
new file mode 100644
index 0000000..e1add8d
--- /dev/null
+++ b/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:68dd5c98eedc8753546f88b4e6fd5fc38725dc0079b837c30fb3d48069ec412b
+size 15002754
diff --git a/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_FCN_weights.pt b/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_FCN_weights.pt
new file mode 100644
index 0000000..d9ca01d
--- /dev/null
+++ b/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_FCN_weights.pt
Binary files differ
diff --git a/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_UNet_weights.pt b/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_UNet_weights.pt
new file mode 100644
index 0000000..0af0e57
--- /dev/null
+++ b/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_UNet_weights.pt
Binary files differ
diff --git a/text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.pt b/text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.pt
new file mode 100644
index 0000000..b5295c2
--- /dev/null
+++ b/text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.pt
Binary files differ