From dc28cbe2b4ed77be92ee8b2b69a20689c3bf02a4 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Sun, 8 Nov 2020 14:54:44 +0100 Subject: new updates --- src/text_recognizer/character_predictor.py | 7 +- src/text_recognizer/datasets/__init__.py | 4 +- src/text_recognizer/datasets/dataset.py | 22 +- src/text_recognizer/datasets/emnist_dataset.py | 6 +- .../datasets/emnist_lines_dataset.py | 51 +++-- src/text_recognizer/datasets/iam_lines_dataset.py | 6 + .../datasets/iam_paragraphs_dataset.py | 7 +- src/text_recognizer/datasets/transforms.py | 40 ++++ src/text_recognizer/datasets/util.py | 31 ++- src/text_recognizer/line_predictor.py | 28 +++ src/text_recognizer/models/__init__.py | 16 +- src/text_recognizer/models/base.py | 66 ++++-- src/text_recognizer/models/character_model.py | 4 +- src/text_recognizer/models/crnn_model.py | 119 ++++++++++ src/text_recognizer/models/line_ctc_model.py | 115 ---------- src/text_recognizer/models/metrics.py | 5 +- .../models/transformer_encoder_model.py | 111 ++++++++++ .../models/vision_transformer_model.py | 119 ++++++++++ src/text_recognizer/networks/__init__.py | 20 +- src/text_recognizer/networks/cnn_transformer.py | 135 ++++++++++++ .../networks/cnn_transformer_encoder.py | 73 +++++++ src/text_recognizer/networks/crnn.py | 108 +++++++++ src/text_recognizer/networks/ctc.py | 2 +- src/text_recognizer/networks/densenet.py | 225 +++++++++++++++++++ src/text_recognizer/networks/lenet.py | 6 +- src/text_recognizer/networks/line_lstm_ctc.py | 120 ---------- src/text_recognizer/networks/loss.py | 69 ++++++ src/text_recognizer/networks/losses.py | 31 --- src/text_recognizer/networks/misc.py | 45 ---- src/text_recognizer/networks/mlp.py | 6 +- src/text_recognizer/networks/residual_network.py | 6 +- src/text_recognizer/networks/sparse_mlp.py | 78 +++++++ src/text_recognizer/networks/transformer.py | 5 - .../networks/transformer/__init__.py | 3 + .../networks/transformer/attention.py | 93 ++++++++ .../networks/transformer/positional_encoding.py | 32 +++ .../networks/transformer/transformer.py | 242 +++++++++++++++++++++ src/text_recognizer/networks/util.py | 83 +++++++ src/text_recognizer/networks/vision_transformer.py | 159 ++++++++++++++ src/text_recognizer/networks/wide_resnet.py | 6 +- .../tests/support/create_emnist_support_files.py | 13 +- src/text_recognizer/tests/test_line_predictor.py | 35 +++ ...ataset_ConvolutionalRecurrentNetwork_weights.pt | Bin 0 -> 5628749 bytes ...haracterModel_EmnistDataset_DenseNet_weights.pt | Bin 0 -> 1273881 bytes .../CharacterModel_EmnistDataset_LeNet_weights.pt | Bin 14485362 -> 0 bytes .../CharacterModel_EmnistDataset_MLP_weights.pt | Bin 17938163 -> 0 bytes ...EmnistDataset_ResidualNetworkEncoder_weights.pt | Bin 26090486 -> 0 bytes ...rModel_EmnistDataset_ResidualNetwork_weights.pt | Bin 32765213 -> 0 bytes ...aracterModel_EmnistDataset_SpinalVGG_weights.pt | Bin 44089479 -> 0 bytes ...el_EmnistDataset_WideResidualNetwork_weights.pt | Bin 0 -> 14953410 bytes .../weights/CharacterModel_Emnist_LeNet_weights.pt | Bin 14485342 -> 0 bytes .../weights/CharacterModel_Emnist_MLP_weights.pt | Bin 1704096 -> 0 bytes ...IamLinesDataset_LineRecurrentNetwork_weights.pt | Bin 20694308 -> 3457858 bytes 53 files changed, 1956 insertions(+), 396 deletions(-) create mode 100644 src/text_recognizer/line_predictor.py create mode 100644 src/text_recognizer/models/crnn_model.py delete mode 100644 src/text_recognizer/models/line_ctc_model.py create mode 100644 src/text_recognizer/models/transformer_encoder_model.py create mode 100644 src/text_recognizer/models/vision_transformer_model.py create mode 100644 src/text_recognizer/networks/cnn_transformer.py create mode 100644 src/text_recognizer/networks/cnn_transformer_encoder.py create mode 100644 src/text_recognizer/networks/crnn.py create mode 100644 src/text_recognizer/networks/densenet.py delete mode 100644 src/text_recognizer/networks/line_lstm_ctc.py create mode 100644 src/text_recognizer/networks/loss.py delete mode 100644 src/text_recognizer/networks/losses.py delete mode 100644 src/text_recognizer/networks/misc.py create mode 100644 src/text_recognizer/networks/sparse_mlp.py delete mode 100644 src/text_recognizer/networks/transformer.py create mode 100644 src/text_recognizer/networks/transformer/__init__.py create mode 100644 src/text_recognizer/networks/transformer/attention.py create mode 100644 src/text_recognizer/networks/transformer/positional_encoding.py create mode 100644 src/text_recognizer/networks/transformer/transformer.py create mode 100644 src/text_recognizer/networks/util.py create mode 100644 src/text_recognizer/networks/vision_transformer.py create mode 100644 src/text_recognizer/tests/test_line_predictor.py create mode 100644 src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt create mode 100644 src/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt delete mode 100644 src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt delete mode 100644 src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt delete mode 100644 src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetworkEncoder_weights.pt delete mode 100644 src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt delete mode 100644 src/text_recognizer/weights/CharacterModel_EmnistDataset_SpinalVGG_weights.pt create mode 100644 src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt delete mode 100644 src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt delete mode 100644 src/text_recognizer/weights/CharacterModel_Emnist_MLP_weights.pt (limited to 'src/text_recognizer') diff --git a/src/text_recognizer/character_predictor.py b/src/text_recognizer/character_predictor.py index df37e68..ad71289 100644 --- a/src/text_recognizer/character_predictor.py +++ b/src/text_recognizer/character_predictor.py @@ -4,6 +4,7 @@ from typing import Dict, Tuple, Type, Union import numpy as np from torch import nn +from text_recognizer import datasets, networks from text_recognizer.models import CharacterModel from text_recognizer.util import read_image @@ -11,9 +12,11 @@ from text_recognizer.util import read_image class CharacterPredictor: """Recognizes the character in handwritten character images.""" - def __init__(self, network_fn: Type[nn.Module]) -> None: + def __init__(self, network_fn: str, dataset: str) -> None: """Intializes the CharacterModel and load the pretrained weights.""" - self.model = CharacterModel(network_fn=network_fn) + network_fn = getattr(networks, network_fn) + dataset = getattr(datasets, dataset) + self.model = CharacterModel(network_fn=network_fn, dataset=dataset) self.model.eval() self.model.use_swa_model() diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py index a3af9b1..d8372e3 100644 --- a/src/text_recognizer/datasets/__init__.py +++ b/src/text_recognizer/datasets/__init__.py @@ -1,5 +1,5 @@ """Dataset modules.""" -from .emnist_dataset import EmnistDataset, Transpose +from .emnist_dataset import EmnistDataset from .emnist_lines_dataset import ( construct_image_from_string, EmnistLinesDataset, @@ -8,6 +8,7 @@ from .emnist_lines_dataset import ( from .iam_dataset import IamDataset from .iam_lines_dataset import IamLinesDataset from .iam_paragraphs_dataset import IamParagraphsDataset +from .transforms import AddTokens, Transpose from .util import ( _download_raw_dataset, compute_sha256, @@ -19,6 +20,7 @@ from .util import ( __all__ = [ "_download_raw_dataset", + "AddTokens", "compute_sha256", "construct_image_from_string", "DATA_DIRNAME", diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py index 05520e5..2de7f09 100644 --- a/src/text_recognizer/datasets/dataset.py +++ b/src/text_recognizer/datasets/dataset.py @@ -18,6 +18,9 @@ class Dataset(data.Dataset): subsample_fraction: float = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, + init_token: Optional[str] = None, + pad_token: Optional[str] = None, + eos_token: Optional[str] = None, ) -> None: """Initialization of Dataset class. @@ -26,12 +29,14 @@ class Dataset(data.Dataset): subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None. transform (Optional[Callable]): Transform(s) for input data. Defaults to None. target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None. + init_token (Optional[str]): String representing the start of sequence token. Defaults to None. + pad_token (Optional[str]): String representing the pad token. Defaults to None. + eos_token (Optional[str]): String representing the end of sequence token. Defaults to None. Raises: ValueError: If subsample_fraction is not None and outside the range (0, 1). """ - self.train = train self.split = "train" if self.train else "test" @@ -40,19 +45,18 @@ class Dataset(data.Dataset): raise ValueError("The subsample fraction must be in (0, 1).") self.subsample_fraction = subsample_fraction - self._mapper = EmnistMapper() + self._mapper = EmnistMapper( + init_token=init_token, eos_token=eos_token, pad_token=pad_token + ) self._input_shape = self._mapper.input_shape self._output_shape = self._mapper._num_classes self.num_classes = self.mapper.num_classes # Set transforms. - self.transform = transform - if self.transform is None: - self.transform = ToTensor() - - self.target_transform = target_transform - if self.target_transform is None: - self.target_transform = torch.tensor + self.transform = transform if transform is not None else ToTensor() + self.target_transform = ( + target_transform if target_transform is not None else torch.tensor + ) self._data = None self._targets = None diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py index d01dcee..9884fdf 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -22,6 +22,7 @@ class EmnistDataset(Dataset): def __init__( self, + pad_token: str = None, train: bool = False, sample_to_balance: bool = False, subsample_fraction: float = None, @@ -32,6 +33,7 @@ class EmnistDataset(Dataset): """Loads the dataset and the mappings. Args: + pad_token (str): The pad token symbol. Defaults to _. train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False. sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False. subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None. @@ -45,6 +47,7 @@ class EmnistDataset(Dataset): subsample_fraction=subsample_fraction, transform=transform, target_transform=target_transform, + pad_token=pad_token, ) self.sample_to_balance = sample_to_balance @@ -53,8 +56,7 @@ class EmnistDataset(Dataset): if transform is None: self.transform = Compose([Transpose(), ToTensor()]) - # The EMNIST dataset is already casted to tensors. - self.target_transform = target_transform + self.target_transform = None self.seed = seed diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py index 6268a01..6871492 100644 --- a/src/text_recognizer/datasets/emnist_lines_dataset.py +++ b/src/text_recognizer/datasets/emnist_lines_dataset.py @@ -4,6 +4,7 @@ from collections import defaultdict from pathlib import Path from typing import Callable, Dict, List, Optional, Tuple, Union +import click import h5py from loguru import logger import numpy as np @@ -37,6 +38,9 @@ class EmnistLinesDataset(Dataset): max_overlap: float = 0.33, num_samples: int = 10000, seed: int = 4711, + init_token: Optional[str] = None, + pad_token: Optional[str] = None, + eos_token: Optional[str] = None, ) -> None: """Set attributes and loads the dataset. @@ -50,13 +54,21 @@ class EmnistLinesDataset(Dataset): max_overlap (float): The maximum overlap between concatenated images. Defaults to 0.33. num_samples (int): Number of samples to generate. Defaults to 10000. seed (int): Seed number. Defaults to 4711. + init_token (Optional[str]): String representing the start of sequence token. Defaults to None. + pad_token (Optional[str]): String representing the pad token. Defaults to None. + eos_token (Optional[str]): String representing the end of sequence token. Defaults to None. """ + self.pad_token = "_" if pad_token is None else pad_token + super().__init__( train=train, transform=transform, target_transform=target_transform, subsample_fraction=subsample_fraction, + init_token=init_token, + pad_token=self.pad_token, + eos_token=eos_token, ) # Extract dataset information. @@ -118,11 +130,7 @@ class EmnistLinesDataset(Dataset): @property def data_filename(self) -> Path: """Path to the h5 file.""" - filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt" - if self.train: - filename = "train_" + filename - else: - filename = "test_" + filename + filename = "train.pt" if self.train else "test.pt" return DATA_DIRNAME / filename def load_or_generate_data(self) -> None: @@ -138,8 +146,8 @@ class EmnistLinesDataset(Dataset): """Loads the dataset from the h5 file.""" logger.debug("EmnistLinesDataset loading data from HDF5...") with h5py.File(self.data_filename, "r") as f: - self._data = f["data"][:] - self._targets = f["targets"][:] + self._data = f["data"][()] + self._targets = f["targets"][()] def _generate_data(self) -> str: """Generates a dataset with the Brown corpus and Emnist characters.""" @@ -148,7 +156,10 @@ class EmnistLinesDataset(Dataset): sentence_generator = SentenceGenerator(self.max_length) # Load emnist dataset. - emnist = EmnistDataset(train=self.train, sample_to_balance=True) + emnist = EmnistDataset( + train=self.train, sample_to_balance=True, pad_token=self.pad_token + ) + emnist.load_or_generate_data() samples_by_character = get_samples_by_character( emnist.data.numpy(), emnist.targets.numpy(), self.mapper.mapping, @@ -298,6 +309,18 @@ def convert_strings_to_categorical_labels( return np.array([[mapping[c] for c in label] for label in labels]) +@click.command() +@click.option( + "--max_length", type=int, default=34, help="Number of characters in a sentence." +) +@click.option( + "--min_overlap", type=float, default=0.0, help="Min overlap between characters." +) +@click.option( + "--max_overlap", type=float, default=0.33, help="Max overlap between characters." +) +@click.option("--num_train", type=int, default=10_000, help="Number of train examples.") +@click.option("--num_test", type=int, default=1_000, help="Number of test examples.") def create_datasets( max_length: int = 34, min_overlap: float = 0, @@ -306,17 +329,17 @@ def create_datasets( num_test: int = 1000, ) -> None: """Creates a training an validation dataset of Emnist lines.""" - emnist_train = EmnistDataset(train=True, sample_to_balance=True) - emnist_test = EmnistDataset(train=False, sample_to_balance=True) - datasets = [emnist_train, emnist_test] num_samples = [num_train, num_test] - for num, train, dataset in zip(num_samples, [True, False], datasets): + for num, train in zip(num_samples, [True, False]): emnist_lines = EmnistLinesDataset( train=train, - emnist=dataset, max_length=max_length, min_overlap=min_overlap, max_overlap=max_overlap, num_samples=num, ) - emnist_lines._load_or_generate_data() + emnist_lines.load_or_generate_data() + + +if __name__ == "__main__": + create_datasets() diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/src/text_recognizer/datasets/iam_lines_dataset.py index 4a74b2b..fdd2fe6 100644 --- a/src/text_recognizer/datasets/iam_lines_dataset.py +++ b/src/text_recognizer/datasets/iam_lines_dataset.py @@ -32,12 +32,18 @@ class IamLinesDataset(Dataset): subsample_fraction: float = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, + init_token: Optional[str] = None, + pad_token: Optional[str] = None, + eos_token: Optional[str] = None, ) -> None: super().__init__( train=train, subsample_fraction=subsample_fraction, transform=transform, target_transform=target_transform, + init_token=init_token, + pad_token=pad_token, + eos_token=eos_token, ) @property diff --git a/src/text_recognizer/datasets/iam_paragraphs_dataset.py b/src/text_recognizer/datasets/iam_paragraphs_dataset.py index 4b34bd1..c1e8fe2 100644 --- a/src/text_recognizer/datasets/iam_paragraphs_dataset.py +++ b/src/text_recognizer/datasets/iam_paragraphs_dataset.py @@ -266,11 +266,16 @@ def _load_iam_paragraphs() -> None: @click.option( "--subsample_fraction", type=float, - default=0.0, + default=None, help="The subsampling factor of the dataset.", ) def main(subsample_fraction: float) -> None: """Load dataset and print info.""" + logger.info("Creating train set...") + dataset = IamParagraphsDataset(train=True, subsample_fraction=subsample_fraction) + dataset.load_or_generate_data() + print(dataset) + logger.info("Creating test set...") dataset = IamParagraphsDataset(subsample_fraction=subsample_fraction) dataset.load_or_generate_data() print(dataset) diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py index 17231a8..8deac7f 100644 --- a/src/text_recognizer/datasets/transforms.py +++ b/src/text_recognizer/datasets/transforms.py @@ -3,6 +3,9 @@ import numpy as np from PIL import Image import torch from torch import Tensor +from torchvision.transforms import Compose, Resize, ToPILImage, ToTensor + +from text_recognizer.datasets.util import EmnistMapper class Transpose: @@ -11,3 +14,40 @@ class Transpose: def __call__(self, image: Image) -> np.ndarray: """Swaps axis.""" return np.array(image).swapaxes(0, 1) + + +class AddTokens: + """Adds start of sequence and end of sequence tokens to target tensor.""" + + def __init__(self, pad_token: str, eos_token: str, init_token: str = None) -> None: + self.init_token = init_token + self.pad_token = pad_token + self.eos_token = eos_token + if self.init_token is not None: + self.emnist_mapper = EmnistMapper( + init_token=self.init_token, + pad_token=self.pad_token, + eos_token=self.eos_token, + ) + else: + self.emnist_mapper = EmnistMapper( + pad_token=self.pad_token, eos_token=self.eos_token, + ) + self.pad_value = self.emnist_mapper(self.pad_token) + self.eos_value = self.emnist_mapper(self.eos_token) + + def __call__(self, target: Tensor) -> Tensor: + """Adds a sos token to the begining and a eos token to the end of a target sequence.""" + dtype, device = target.dtype, target.device + + # Find the where padding starts. + pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item() + + target[pad_index] = self.eos_value + + if self.init_token is not None: + self.sos_value = self.emnist_mapper(self.init_token) + sos = torch.tensor([self.sos_value], dtype=dtype, device=device) + target = torch.cat([sos, target], dim=0) + + return target diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py index 73968a1..d2df8b5 100644 --- a/src/text_recognizer/datasets/util.py +++ b/src/text_recognizer/datasets/util.py @@ -4,6 +4,7 @@ import importlib import json import os from pathlib import Path +import string from typing import Callable, Dict, List, Optional, Type, Union from urllib.request import urlopen, urlretrieve @@ -26,7 +27,7 @@ def save_emnist_essentials(emnsit_dataset: type = EMNIST) -> None: mapping = [(i, str(label)) for i, label in enumerate(labels)] essentials = { "mapping": mapping, - "input_shape": tuple(emnsit_dataset[0][0].shape[:]), + "input_shape": tuple(np.array(emnsit_dataset[0][0]).shape[:]), } logger.info("Saving emnist essentials...") with open(ESSENTIALS_FILENAME, "w") as f: @@ -43,11 +44,21 @@ def download_emnist() -> None: class EmnistMapper: """Mapper between network output to Emnist character.""" - def __init__(self) -> None: + def __init__( + self, + pad_token: str, + init_token: Optional[str] = None, + eos_token: Optional[str] = None, + ) -> None: """Loads the emnist essentials file with the mapping and input shape.""" + self.init_token = init_token + self.pad_token = pad_token + self.eos_token = eos_token + self.essentials = self._load_emnist_essentials() # Load dataset infromation. - self._mapping = self._augment_emnist_mapping(dict(self.essentials["mapping"])) + self._mapping = dict(self.essentials["mapping"]) + self._augment_emnist_mapping() self._inverse_mapping = {v: k for k, v in self.mapping.items()} self._num_classes = len(self.mapping) self._input_shape = self.essentials["input_shape"] @@ -103,7 +114,7 @@ class EmnistMapper: essentials = json.load(f) return essentials - def _augment_emnist_mapping(self, mapping: Dict) -> Dict: + def _augment_emnist_mapping(self) -> None: """Augment the mapping with extra symbols.""" # Extra symbols in IAM dataset extra_symbols = [ @@ -127,14 +138,20 @@ class EmnistMapper: ] # padding symbol, and acts as blank symbol as well. - extra_symbols.append("_") + extra_symbols.append(self.pad_token) + + if self.init_token is not None: + extra_symbols.append(self.init_token) + + if self.eos_token is not None: + extra_symbols.append(self.eos_token) - max_key = max(mapping.keys()) + max_key = max(self.mapping.keys()) extra_mapping = {} for i, symbol in enumerate(extra_symbols): extra_mapping[max_key + 1 + i] = symbol - return {**mapping, **extra_mapping} + self._mapping = {**self.mapping, **extra_mapping} def compute_sha256(filename: Union[Path, str]) -> str: diff --git a/src/text_recognizer/line_predictor.py b/src/text_recognizer/line_predictor.py new file mode 100644 index 0000000..981e2c9 --- /dev/null +++ b/src/text_recognizer/line_predictor.py @@ -0,0 +1,28 @@ +"""LinePredictor class.""" +import importlib +from typing import Tuple, Union + +import numpy as np +from torch import nn + +from text_recognizer import datasets, networks +from text_recognizer.models import VisionTransformerModel +from text_recognizer.util import read_image + + +class LinePredictor: + """Given an image of a line of handwritten text, recognizes the text content.""" + + def __init__(self, dataset: str, network_fn: str) -> None: + network_fn = getattr(networks, network_fn) + dataset = getattr(datasets, dataset) + self.model = VisionTransformerModel(network_fn=network_fn, dataset=dataset) + self.model.eval() + + def predict(self, image_or_filename: Union[np.ndarray, str]) -> Tuple[str, float]: + """Predict on a single images contianing a handwritten character.""" + if isinstance(image_or_filename, str): + image = read_image(image_or_filename, grayscale=True) + else: + image = image_or_filename + return self.model.predict_on_image(image) diff --git a/src/text_recognizer/models/__init__.py b/src/text_recognizer/models/__init__.py index a3cfc15..28aa52e 100644 --- a/src/text_recognizer/models/__init__.py +++ b/src/text_recognizer/models/__init__.py @@ -1,7 +1,19 @@ """Model modules.""" from .base import Model from .character_model import CharacterModel -from .line_ctc_model import LineCTCModel +from .crnn_model import CRNNModel from .metrics import accuracy, cer, wer +from .transformer_encoder_model import TransformerEncoderModel +from .vision_transformer_model import VisionTransformerModel -__all__ = ["Model", "cer", "CharacterModel", "LineCTCModel", "accuracy", "wer"] +__all__ = [ + "Model", + "cer", + "CharacterModel", + "CRNNModel", + "CNNTransfromerModel", + "accuracy", + "TransformerEncoderModel", + "VisionTransformerModel", + "wer", +] diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index caf8065..cc44c92 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -6,7 +6,7 @@ import importlib from pathlib import Path import re import shutil -from typing import Callable, Dict, Optional, Tuple, Type +from typing import Callable, Dict, List, Optional, Tuple, Type, Union from loguru import logger import torch @@ -15,6 +15,7 @@ from torch import Tensor from torch.optim.swa_utils import AveragedModel, SWALR from torch.utils.data import DataLoader, Dataset, random_split from torchsummary import summary +from torchvision.transforms import Compose from text_recognizer.datasets import EmnistMapper @@ -128,16 +129,42 @@ class Model(ABC): self._configure_criterion() self._configure_optimizers() - # Prints a summary of the network in terminal. - self.summary() - # Set this flag to true to prevent the model from configuring again. self.is_configured = True + def _configure_transforms(self) -> None: + # Load transforms. + transforms_module = importlib.import_module( + "text_recognizer.datasets.transforms" + ) + if ( + "transform" in self.dataset_args["args"] + and self.dataset_args["args"]["transform"] is not None + ): + transform_ = [] + for t in self.dataset_args["args"]["transform"]: + args = t["args"] or {} + transform_.append(getattr(transforms_module, t["type"])(**args)) + self.dataset_args["args"]["transform"] = Compose(transform_) + + if ( + "target_transform" in self.dataset_args["args"] + and self.dataset_args["args"]["target_transform"] is not None + ): + target_transform_ = [ + torch.tensor, + ] + for t in self.dataset_args["args"]["target_transform"]: + args = t["args"] or {} + target_transform_.append(getattr(transforms_module, t["type"])(**args)) + self.dataset_args["args"]["target_transform"] = Compose(target_transform_) + def prepare_data(self) -> None: """Prepare data for training.""" # TODO add downloading. if not self.data_prepared: + self._configure_transforms() + # Load train dataset. train_dataset = self.dataset(train=True, **self.dataset_args["args"]) train_dataset.load_or_generate_data() @@ -327,20 +354,20 @@ class Model(ABC): else: return self.network(x) - def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor: - """Compute the loss.""" - return self.criterion(output, targets) - def summary( - self, input_shape: Optional[Tuple[int, int, int]] = None, depth: int = 3 + self, + input_shape: Optional[Union[List, Tuple]] = None, + depth: int = 4, + device: Optional[str] = None, ) -> None: """Prints a summary of the network architecture.""" + device = self.device if device is None else device if input_shape is not None: - summary(self.network, input_shape, depth=depth, device=self.device) + summary(self.network, input_shape, depth=depth, device=device) elif self._input_shape is not None: input_shape = (1,) + tuple(self._input_shape) - summary(self.network, input_shape, depth=depth, device=self.device) + summary(self.network, input_shape, depth=depth, device=device) else: logger.warning("Could not print summary as input shape is not set.") @@ -356,25 +383,29 @@ class Model(ABC): state["optimizer_state"] = self._optimizer.state_dict() if self._lr_scheduler is not None: - state["scheduler_state"] = self._lr_scheduler.state_dict() + state["scheduler_state"] = self._lr_scheduler["lr_scheduler"].state_dict() + state["scheduler_interval"] = self._lr_scheduler["interval"] if self._swa_network is not None: state["swa_network"] = self._swa_network.state_dict() return state - def load_from_checkpoint(self, checkpoint_path: Path) -> None: + def load_from_checkpoint(self, checkpoint_path: Union[str, Path]) -> None: """Load a previously saved checkpoint. Args: checkpoint_path (Path): Path to the experiment with the checkpoint. """ + checkpoint_path = Path(checkpoint_path) + self.prepare_data() + self.configure_model() logger.debug("Loading checkpoint...") if not checkpoint_path.exists(): logger.debug("File does not exist {str(checkpoint_path)}") - checkpoint = torch.load(str(checkpoint_path)) + checkpoint = torch.load(str(checkpoint_path), map_location=self.device) self._network.load_state_dict(checkpoint["model_state"]) if self._optimizer is not None: @@ -383,8 +414,11 @@ class Model(ABC): if self._lr_scheduler is not None: # Does not work when loadning from previous checkpoint and trying to train beyond the last max epochs # with OneCycleLR. - if self._lr_scheduler.__class__.__name__ != "OneCycleLR": - self._lr_scheduler.load_state_dict(checkpoint["scheduler_state"]) + if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR": + self._lr_scheduler["lr_scheduler"].load_state_dict( + checkpoint["scheduler_state"] + ) + self._lr_scheduler["interval"] = checkpoint["scheduler_interval"] if self._swa_network is not None: self._swa_network.load_state_dict(checkpoint["swa_network"]) diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py index 50e94a2..f9944f3 100644 --- a/src/text_recognizer/models/character_model.py +++ b/src/text_recognizer/models/character_model.py @@ -47,8 +47,9 @@ class CharacterModel(Model): swa_args, device, ) + self.pad_token = dataset_args["args"]["pad_token"] if self._mapper is None: - self._mapper = EmnistMapper() + self._mapper = EmnistMapper(pad_token=self.pad_token,) self.tensor_transform = ToTensor() self.softmax = nn.Softmax(dim=0) @@ -65,6 +66,7 @@ class CharacterModel(Model): Tuple[str, float]: The predicted character and the confidence in the prediction. """ + self.eval() if image.dtype == np.uint8: # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. diff --git a/src/text_recognizer/models/crnn_model.py b/src/text_recognizer/models/crnn_model.py new file mode 100644 index 0000000..1e01a83 --- /dev/null +++ b/src/text_recognizer/models/crnn_model.py @@ -0,0 +1,119 @@ +"""Defines the CRNNModel class.""" +from typing import Callable, Dict, Optional, Tuple, Type, Union + +import numpy as np +import torch +from torch import nn +from torch import Tensor +from torch.utils.data import Dataset +from torchvision.transforms import ToTensor + +from text_recognizer.datasets import EmnistMapper +from text_recognizer.models.base import Model +from text_recognizer.networks import greedy_decoder + + +class CRNNModel(Model): + """Model for predicting a sequence of characters from an image of a text line.""" + + def __init__( + self, + network_fn: Type[nn.Module], + dataset: Type[Dataset], + network_args: Optional[Dict] = None, + dataset_args: Optional[Dict] = None, + metrics: Optional[Dict] = None, + criterion: Optional[Callable] = None, + criterion_args: Optional[Dict] = None, + optimizer: Optional[Callable] = None, + optimizer_args: Optional[Dict] = None, + lr_scheduler: Optional[Callable] = None, + lr_scheduler_args: Optional[Dict] = None, + swa_args: Optional[Dict] = None, + device: Optional[str] = None, + ) -> None: + super().__init__( + network_fn, + dataset, + network_args, + dataset_args, + metrics, + criterion, + criterion_args, + optimizer, + optimizer_args, + lr_scheduler, + lr_scheduler_args, + swa_args, + device, + ) + + self.pad_token = dataset_args["args"]["pad_token"] + if self._mapper is None: + self._mapper = EmnistMapper(pad_token=self.pad_token,) + self.tensor_transform = ToTensor() + + def criterion(self, output: Tensor, targets: Tensor) -> Tensor: + """Computes the CTC loss. + + Args: + output (Tensor): Model predictions. + targets (Tensor): Correct output sequence. + + Returns: + Tensor: The CTC loss. + + """ + + # Input lengths on the form [T, B] + input_lengths = torch.full( + size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long, + ) + + # Configure target tensors for ctc loss. + targets_ = Tensor([]).to(self.device) + target_lengths = [] + for t in targets: + # Remove padding symbol as it acts as the blank symbol. + t = t[t < 79] + targets_ = torch.cat([targets_, t]) + target_lengths.append(len(t)) + + targets = targets_.type(dtype=torch.long) + target_lengths = ( + torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device) + ) + + return self._criterion(output, targets, input_lengths, target_lengths) + + @torch.no_grad() + def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: + """Predict on a single input.""" + self.eval() + + if image.dtype == np.uint8: + # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. + image = self.tensor_transform(image) + + # Rescale image between 0 and 1. + if image.dtype == torch.uint8: + # If the image is an unscaled tensor. + image = image.type("torch.FloatTensor") / 255 + + # Put the image tensor on the device the model weights are on. + image = image.to(self.device) + log_probs = self.forward(image) + + raw_pred, _ = greedy_decoder( + predictions=log_probs, + character_mapper=self.mapper, + blank_label=79, + collapse_repeated=True, + ) + + log_probs, _ = log_probs.max(dim=2) + + predicted_characters = "".join(raw_pred[0]) + confidence_of_prediction = log_probs.cumprod(dim=0)[-1].item() + + return predicted_characters, confidence_of_prediction diff --git a/src/text_recognizer/models/line_ctc_model.py b/src/text_recognizer/models/line_ctc_model.py deleted file mode 100644 index 16eaed3..0000000 --- a/src/text_recognizer/models/line_ctc_model.py +++ /dev/null @@ -1,115 +0,0 @@ -"""Defines the LineCTCModel class.""" -from typing import Callable, Dict, Optional, Tuple, Type, Union - -import numpy as np -import torch -from torch import nn -from torch import Tensor -from torch.utils.data import Dataset -from torchvision.transforms import ToTensor - -from text_recognizer.datasets import EmnistMapper -from text_recognizer.models.base import Model -from text_recognizer.networks import greedy_decoder - - -class LineCTCModel(Model): - """Model for predicting a sequence of characters from an image of a text line.""" - - def __init__( - self, - network_fn: Type[nn.Module], - dataset: Type[Dataset], - network_args: Optional[Dict] = None, - dataset_args: Optional[Dict] = None, - metrics: Optional[Dict] = None, - criterion: Optional[Callable] = None, - criterion_args: Optional[Dict] = None, - optimizer: Optional[Callable] = None, - optimizer_args: Optional[Dict] = None, - lr_scheduler: Optional[Callable] = None, - lr_scheduler_args: Optional[Dict] = None, - swa_args: Optional[Dict] = None, - device: Optional[str] = None, - ) -> None: - super().__init__( - network_fn, - dataset, - network_args, - dataset_args, - metrics, - criterion, - criterion_args, - optimizer, - optimizer_args, - lr_scheduler, - lr_scheduler_args, - swa_args, - device, - ) - if self._mapper is None: - self._mapper = EmnistMapper() - self.tensor_transform = ToTensor() - - def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor: - """Computes the CTC loss. - - Args: - output (Tensor): Model predictions. - targets (Tensor): Correct output sequence. - - Returns: - Tensor: The CTC loss. - - """ - - # Input lengths on the form [T, B] - input_lengths = torch.full( - size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long, - ) - - # Configure target tensors for ctc loss. - targets_ = Tensor([]).to(self.device) - target_lengths = [] - for t in targets: - # Remove padding symbol as it acts as the blank symbol. - t = t[t < 79] - targets_ = torch.cat([targets_, t]) - target_lengths.append(len(t)) - - targets = targets_.type(dtype=torch.long) - target_lengths = ( - torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device) - ) - - return self.criterion(output, targets, input_lengths, target_lengths) - - @torch.no_grad() - def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: - """Predict on a single input.""" - if image.dtype == np.uint8: - # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. - image = self.tensor_transform(image) - - # Rescale image between 0 and 1. - if image.dtype == torch.uint8: - # If the image is an unscaled tensor. - image = image.type("torch.FloatTensor") / 255 - - # Put the image tensor on the device the model weights are on. - image = image.to(self.device) - log_probs = self.forward(image) - - raw_pred, _ = greedy_decoder( - predictions=log_probs, - character_mapper=self.mapper, - blank_label=79, - collapse_repeated=True, - ) - - log_probs, _ = log_probs.max(dim=2) - - predicted_characters = "".join(raw_pred[0]) - confidence_of_prediction = torch.exp(log_probs.sum()).item() - - return predicted_characters, confidence_of_prediction diff --git a/src/text_recognizer/models/metrics.py b/src/text_recognizer/models/metrics.py index 6a26216..42c3c6e 100644 --- a/src/text_recognizer/models/metrics.py +++ b/src/text_recognizer/models/metrics.py @@ -17,7 +17,10 @@ def accuracy(outputs: Tensor, labels: Tensor) -> float: float: The accuracy for the batch. """ - _, predicted = torch.max(outputs.data, dim=1) + # eos_index = torch.nonzero(labels == eos, as_tuple=False) + # eos_index = eos_index[0].item() if eos_index.nelement() else -1 + + _, predicted = torch.max(outputs, dim=-1) acc = (predicted == labels).sum().float() / labels.shape[0] acc = acc.item() return acc diff --git a/src/text_recognizer/models/transformer_encoder_model.py b/src/text_recognizer/models/transformer_encoder_model.py new file mode 100644 index 0000000..e35e298 --- /dev/null +++ b/src/text_recognizer/models/transformer_encoder_model.py @@ -0,0 +1,111 @@ +"""Defines the CNN-Transformer class.""" +from typing import Callable, Dict, List, Optional, Tuple, Type, Union + +import numpy as np +import torch +from torch import nn +from torch import Tensor +from torch.utils.data import Dataset +from torchvision.transforms import ToTensor + +from text_recognizer.datasets import EmnistMapper +from text_recognizer.models.base import Model + + +class TransformerEncoderModel(Model): + """A class for only using the encoder part in the sequence modelling.""" + + def __init__( + self, + network_fn: Type[nn.Module], + dataset: Type[Dataset], + network_args: Optional[Dict] = None, + dataset_args: Optional[Dict] = None, + metrics: Optional[Dict] = None, + criterion: Optional[Callable] = None, + criterion_args: Optional[Dict] = None, + optimizer: Optional[Callable] = None, + optimizer_args: Optional[Dict] = None, + lr_scheduler: Optional[Callable] = None, + lr_scheduler_args: Optional[Dict] = None, + swa_args: Optional[Dict] = None, + device: Optional[str] = None, + ) -> None: + super().__init__( + network_fn, + dataset, + network_args, + dataset_args, + metrics, + criterion, + criterion_args, + optimizer, + optimizer_args, + lr_scheduler, + lr_scheduler_args, + swa_args, + device, + ) + # self.init_token = dataset_args["args"]["init_token"] + self.pad_token = dataset_args["args"]["pad_token"] + self.eos_token = dataset_args["args"]["eos_token"] + if network_args is not None: + self.max_len = network_args["max_len"] + else: + self.max_len = 128 + + if self._mapper is None: + self._mapper = EmnistMapper( + # init_token=self.init_token, + pad_token=self.pad_token, + eos_token=self.eos_token, + ) + self.tensor_transform = ToTensor() + + self.softmax = nn.Softmax(dim=2) + + @torch.no_grad() + def _generate_sentence(self, image: Tensor) -> Tuple[List, float]: + logits = self.network(image) + # Convert logits to probabilities. + probs = self.softmax(logits).squeeze(0) + + confidence, pred_tokens = probs.max(1) + pred_tokens = pred_tokens + + eos_index = torch.nonzero( + pred_tokens == self._mapper(self.eos_token), as_tuple=False, + ) + + eos_index = eos_index[0].item() if eos_index.nelement() else -1 + + predicted_characters = "".join( + [self.mapper(x) for x in pred_tokens[:eos_index].tolist()] + ) + + confidence = np.min(confidence.tolist()) + + return predicted_characters, confidence + + @torch.no_grad() + def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: + """Predict on a single input.""" + self.eval() + + if image.dtype == np.uint8: + # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. + image = self.tensor_transform(image) + + # Rescale image between 0 and 1. + if image.dtype == torch.uint8: + # If the image is an unscaled tensor. + image = image.type("torch.FloatTensor") / 255 + + # Put the image tensor on the device the model weights are on. + image = image.to(self.device) + + (predicted_characters, confidence_of_prediction,) = self._generate_sentence( + image + ) + + return predicted_characters, confidence_of_prediction diff --git a/src/text_recognizer/models/vision_transformer_model.py b/src/text_recognizer/models/vision_transformer_model.py new file mode 100644 index 0000000..3d36437 --- /dev/null +++ b/src/text_recognizer/models/vision_transformer_model.py @@ -0,0 +1,119 @@ +"""Defines the CNN-Transformer class.""" +from typing import Callable, Dict, List, Optional, Tuple, Type, Union + +import numpy as np +import torch +from torch import nn +from torch import Tensor +from torch.utils.data import Dataset +from torchvision.transforms import ToTensor + +from text_recognizer.datasets import EmnistMapper +from text_recognizer.models.base import Model +from text_recognizer.networks import greedy_decoder + + +class VisionTransformerModel(Model): + """Model for predicting a sequence of characters from an image of a text line with a cnn-transformer.""" + + def __init__( + self, + network_fn: Type[nn.Module], + dataset: Type[Dataset], + network_args: Optional[Dict] = None, + dataset_args: Optional[Dict] = None, + metrics: Optional[Dict] = None, + criterion: Optional[Callable] = None, + criterion_args: Optional[Dict] = None, + optimizer: Optional[Callable] = None, + optimizer_args: Optional[Dict] = None, + lr_scheduler: Optional[Callable] = None, + lr_scheduler_args: Optional[Dict] = None, + swa_args: Optional[Dict] = None, + device: Optional[str] = None, + ) -> None: + super().__init__( + network_fn, + dataset, + network_args, + dataset_args, + metrics, + criterion, + criterion_args, + optimizer, + optimizer_args, + lr_scheduler, + lr_scheduler_args, + swa_args, + device, + ) + self.init_token = dataset_args["args"]["init_token"] + self.pad_token = dataset_args["args"]["pad_token"] + self.eos_token = dataset_args["args"]["eos_token"] + if network_args is not None: + self.max_len = network_args["max_len"] + else: + self.max_len = 120 + + if self._mapper is None: + self._mapper = EmnistMapper( + init_token=self.init_token, + pad_token=self.pad_token, + eos_token=self.eos_token, + ) + self.tensor_transform = ToTensor() + + self.softmax = nn.Softmax(dim=2) + + @torch.no_grad() + def _generate_sentence(self, image: Tensor) -> Tuple[List, float]: + src = self.network.preprocess_input(image) + memory = self.network.encoder(src) + + confidence_of_predictions = [] + trg_indices = [self.mapper(self.init_token)] + + for _ in range(self.max_len - 1): + trg = torch.tensor(trg_indices, device=self.device)[None, :].long() + trg = self.network.preprocess_target(trg) + logits = self.network.decoder(trg=trg, memory=memory, trg_mask=None) + + # Convert logits to probabilities. + probs = self.softmax(logits) + + pred_token = probs.argmax(2)[:, -1].item() + confidence = probs.max(2).values[:, -1].item() + + trg_indices.append(pred_token) + confidence_of_predictions.append(confidence) + + if pred_token == self.mapper(self.eos_token): + break + + confidence = np.min(confidence_of_predictions) + predicted_characters = "".join([self.mapper(x) for x in trg_indices[1:]]) + + return predicted_characters, confidence + + @torch.no_grad() + def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: + """Predict on a single input.""" + self.eval() + + if image.dtype == np.uint8: + # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. + image = self.tensor_transform(image) + + # Rescale image between 0 and 1. + if image.dtype == torch.uint8: + # If the image is an unscaled tensor. + image = image.type("torch.FloatTensor") / 255 + + # Put the image tensor on the device the model weights are on. + image = image.to(self.device) + + (predicted_characters, confidence_of_prediction,) = self._generate_sentence( + image + ) + + return predicted_characters, confidence_of_prediction diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py index a39975f..6d88768 100644 --- a/src/text_recognizer/networks/__init__.py +++ b/src/text_recognizer/networks/__init__.py @@ -1,21 +1,33 @@ """Network modules.""" +from .cnn_transformer import CNNTransformer +from .cnn_transformer_encoder import CNNTransformerEncoder +from .crnn import ConvolutionalRecurrentNetwork from .ctc import greedy_decoder +from .densenet import DenseNet from .lenet import LeNet -from .line_lstm_ctc import LineRecurrentNetwork -from .losses import EmbeddingLoss -from .misc import sliding_window +from .loss import EmbeddingLoss from .mlp import MLP from .residual_network import ResidualNetwork, ResidualNetworkEncoder +from .sparse_mlp import SparseMLP +from .transformer import Transformer +from .util import sliding_window +from .vision_transformer import VisionTransformer from .wide_resnet import WideResidualNetwork __all__ = [ + "CNNTransformer", + "CNNTransformerEncoder", + "ConvolutionalRecurrentNetwork", + "DenseNet", "EmbeddingLoss", "greedy_decoder", "MLP", "LeNet", - "LineRecurrentNetwork", "ResidualNetwork", "ResidualNetworkEncoder", "sliding_window", + "Transformer", + "SparseMLP", + "VisionTransformer", "WideResidualNetwork", ] diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py new file mode 100644 index 0000000..3da2c9f --- /dev/null +++ b/src/text_recognizer/networks/cnn_transformer.py @@ -0,0 +1,135 @@ +"""A DETR style transfomers but for text recognition.""" +from typing import Dict, Optional, Tuple + +from einops import rearrange +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.transformer import PositionalEncoding, Transformer +from text_recognizer.networks.util import configure_backbone + + +class CNNTransformer(nn.Module): + """CNN+Transfomer for image to sequence prediction, sort of based on the ideas from DETR.""" + + def __init__( + self, + num_encoder_layers: int, + num_decoder_layers: int, + hidden_dim: int, + vocab_size: int, + num_heads: int, + adaptive_pool_dim: Tuple, + expansion_dim: int, + dropout_rate: float, + trg_pad_index: int, + backbone: str, + out_channels: int, + max_len: int, + backbone_args: Optional[Dict] = None, + activation: str = "gelu", + ) -> None: + super().__init__() + self.trg_pad_index = trg_pad_index + + self.backbone = configure_backbone(backbone, backbone_args) + self.character_embedding = nn.Embedding(vocab_size, hidden_dim) + + # self.conv = nn.Conv2d(out_channels, max_len, kernel_size=1) + + self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate) + self.row_embed = nn.Parameter(torch.rand(max_len, max_len // 2)) + self.col_embed = nn.Parameter(torch.rand(max_len, max_len // 2)) + + self.adaptive_pool = ( + nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None + ) + + self.transformer = Transformer( + num_encoder_layers, + num_decoder_layers, + hidden_dim, + num_heads, + expansion_dim, + dropout_rate, + activation, + ) + + self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),) + + def _create_trg_mask(self, trg: Tensor) -> Tensor: + # Move this outside the transformer. + trg_pad_mask = (trg != self.trg_pad_index)[:, None, None] + trg_len = trg.shape[1] + trg_sub_mask = torch.tril( + torch.ones((trg_len, trg_len), device=trg.device) + ).bool() + trg_mask = trg_pad_mask & trg_sub_mask + return trg_mask + + def encoder(self, src: Tensor) -> Tensor: + """Forward pass with the encoder of the transformer.""" + return self.transformer.encoder(src) + + def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor: + """Forward pass with the decoder of the transformer + classification head.""" + return self.head( + self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask) + ) + + def preprocess_input(self, src: Tensor) -> Tensor: + """Encodes src with a backbone network and a positional encoding. + + Args: + src (Tensor): Input tensor. + + Returns: + Tensor: A input src to the transformer. + + """ + # If batch dimenstion is missing, it needs to be added. + if len(src.shape) < 4: + src = src[(None,) * (4 - len(src.shape))] + src = self.backbone(src) + # src = self.conv(src) + if self.adaptive_pool is not None: + src = self.adaptive_pool(src) + H, W = src.shape[-2:] + src = rearrange(src, "b t h w -> b t (h w)") + + # construct positional encodings + pos = torch.cat( + [ + self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1), + self.row_embed[:H].unsqueeze(1).repeat(1, W, 1), + ], + dim=-1, + ).unsqueeze(0) + pos = rearrange(pos, "b h w l -> b l (h w)") + src = pos + 0.1 * src + return src + + def preprocess_target(self, trg: Tensor) -> Tuple[Tensor, Tensor]: + """Encodes target tensor with embedding and postion. + + Args: + trg (Tensor): Target tensor. + + Returns: + Tuple[Tensor, Tensor]: Encoded target tensor and target mask. + + """ + trg = self.character_embedding(trg.long()) + trg = self.position_encoding(trg) + return trg + + def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor: + """Forward pass with CNN transfomer.""" + h = self.preprocess_input(x) + trg_mask = self._create_trg_mask(trg) + trg = self.preprocess_target(trg) + out = self.transformer(h, trg, trg_mask=trg_mask) + + logits = self.head(out) + return logits diff --git a/src/text_recognizer/networks/cnn_transformer_encoder.py b/src/text_recognizer/networks/cnn_transformer_encoder.py new file mode 100644 index 0000000..93626bf --- /dev/null +++ b/src/text_recognizer/networks/cnn_transformer_encoder.py @@ -0,0 +1,73 @@ +"""Network with a CNN backend and a transformer encoder head.""" +from typing import Dict + +from einops import rearrange +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.transformer import PositionalEncoding +from text_recognizer.networks.util import configure_backbone + + +class CNNTransformerEncoder(nn.Module): + """A CNN backbone with Transformer Encoder frontend for sequence prediction.""" + + def __init__( + self, + backbone: str, + backbone_args: Dict, + mlp_dim: int, + d_model: int, + nhead: int = 8, + dropout_rate: float = 0.1, + activation: str = "relu", + num_layers: int = 6, + num_classes: int = 80, + num_channels: int = 256, + max_len: int = 97, + ) -> None: + super().__init__() + self.d_model = d_model + self.nhead = nhead + self.dropout_rate = dropout_rate + self.activation = activation + self.num_layers = num_layers + + self.backbone = configure_backbone(backbone, backbone_args) + self.position_encoding = PositionalEncoding(d_model, dropout_rate) + self.encoder = self._configure_encoder() + + self.conv = nn.Conv2d(num_channels, max_len, kernel_size=1) + + self.mlp = nn.Linear(mlp_dim, d_model) + + self.head = nn.Linear(d_model, num_classes) + + def _configure_encoder(self) -> nn.TransformerEncoder: + encoder_layer = nn.TransformerEncoderLayer( + d_model=self.d_model, + nhead=self.nhead, + dropout=self.dropout_rate, + activation=self.activation, + ) + norm = nn.LayerNorm(self.d_model) + return nn.TransformerEncoder( + encoder_layer=encoder_layer, num_layers=self.num_layers, norm=norm + ) + + def forward(self, x: Tensor, targets: Tensor = None) -> Tensor: + """Forward pass through the network.""" + if len(x.shape) < 4: + x = x[(None,) * (4 - len(x.shape))] + + x = self.conv(self.backbone(x)) + x = rearrange(x, "b c h w -> b c (h w)") + x = self.mlp(x) + x = self.position_encoding(x) + x = rearrange(x, "b c h-> c b h") + x = self.encoder(x) + x = rearrange(x, "c b h-> b c h") + logits = self.head(x) + + return logits diff --git a/src/text_recognizer/networks/crnn.py b/src/text_recognizer/networks/crnn.py new file mode 100644 index 0000000..9747429 --- /dev/null +++ b/src/text_recognizer/networks/crnn.py @@ -0,0 +1,108 @@ +"""LSTM with CTC for handwritten text recognition within a line.""" +from typing import Dict, Tuple + +from einops import rearrange, reduce +from einops.layers.torch import Rearrange +from loguru import logger +from torch import nn +from torch import Tensor + +from text_recognizer.networks.util import configure_backbone + + +class ConvolutionalRecurrentNetwork(nn.Module): + """Network that takes a image of a text line and predicts tokens that are in the image.""" + + def __init__( + self, + backbone: str, + backbone_args: Dict = None, + input_size: int = 128, + hidden_size: int = 128, + bidirectional: bool = False, + num_layers: int = 1, + num_classes: int = 80, + patch_size: Tuple[int, int] = (28, 28), + stride: Tuple[int, int] = (1, 14), + recurrent_cell: str = "lstm", + avg_pool: bool = False, + use_sliding_window: bool = True, + ) -> None: + super().__init__() + self.backbone_args = backbone_args or {} + self.patch_size = patch_size + self.stride = stride + self.sliding_window = ( + self._configure_sliding_window() if use_sliding_window else None + ) + self.input_size = input_size + self.hidden_size = hidden_size + self.backbone = configure_backbone(backbone, backbone_args) + self.bidirectional = bidirectional + self.avg_pool = avg_pool + + if recurrent_cell.upper() in ["LSTM", "GRU"]: + recurrent_cell = getattr(nn, recurrent_cell) + else: + logger.warning( + f"Option {recurrent_cell} not valid, defaulting to LSTM cell." + ) + recurrent_cell = nn.LSTM + + self.rnn = recurrent_cell( + input_size=self.input_size, + hidden_size=self.hidden_size, + bidirectional=bidirectional, + num_layers=num_layers, + ) + + decoder_size = self.hidden_size * 2 if self.bidirectional else self.hidden_size + + self.decoder = nn.Sequential( + nn.Linear(in_features=decoder_size, out_features=num_classes), + nn.LogSoftmax(dim=2), + ) + + def _configure_sliding_window(self) -> nn.Sequential: + return nn.Sequential( + nn.Unfold(kernel_size=self.patch_size, stride=self.stride), + Rearrange( + "b (c h w) t -> b t c h w", + h=self.patch_size[0], + w=self.patch_size[1], + c=1, + ), + ) + + def forward(self, x: Tensor) -> Tensor: + """Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM.""" + if len(x.shape) < 4: + x = x[(None,) * (4 - len(x.shape))] + + if self.sliding_window is not None: + # Create image patches with a sliding window kernel. + x = self.sliding_window(x) + + # Rearrange from a sequence of patches for feedforward network. + b, t = x.shape[:2] + x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t) + + x = self.backbone(x) + + # Avgerage pooling. + if self.avg_pool: + x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t) + else: + x = rearrange(x, "(b t) h -> t b h", b=b, t=t) + else: + # Encode the entire image with a CNN, and use the channels as temporal dimension. + b = x.shape[0] + x = self.backbone(x) + x = rearrange(x, "b c h w -> c b (h w)", b=b) + + # Sequence predictions. + x, _ = self.rnn(x) + + # Sequence to classifcation layer. + x = self.decoder(x) + return x diff --git a/src/text_recognizer/networks/ctc.py b/src/text_recognizer/networks/ctc.py index 2493d5c..af9b700 100644 --- a/src/text_recognizer/networks/ctc.py +++ b/src/text_recognizer/networks/ctc.py @@ -33,7 +33,7 @@ def greedy_decoder( """ if character_mapper is None: - character_mapper = EmnistMapper() + character_mapper = EmnistMapper(pad_token="_") # noqa: S106 predictions = rearrange(torch.argmax(predictions, dim=2), "t b -> b t") decoded_predictions = [] diff --git a/src/text_recognizer/networks/densenet.py b/src/text_recognizer/networks/densenet.py new file mode 100644 index 0000000..7dc58d9 --- /dev/null +++ b/src/text_recognizer/networks/densenet.py @@ -0,0 +1,225 @@ +"""Defines a Densely Connected Convolutional Networks in PyTorch. + +Sources: +https://arxiv.org/abs/1608.06993 +https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py + +""" +from typing import List, Optional, Union + +from einops.layers.torch import Rearrange +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.util import activation_function + + +class _DenseLayer(nn.Module): + """A dense layer with pre-batch norm -> activation function -> Conv-layer x 2.""" + + def __init__( + self, + in_channels: int, + growth_rate: int, + bn_size: int, + dropout_rate: float, + activation: str = "relu", + ) -> None: + super().__init__() + activation_fn = activation_function(activation) + self.dense_layer = [ + nn.BatchNorm2d(in_channels), + activation_fn, + nn.Conv2d( + in_channels=in_channels, + out_channels=bn_size * growth_rate, + kernel_size=1, + stride=1, + bias=False, + ), + nn.BatchNorm2d(bn_size * growth_rate), + activation_fn, + nn.Conv2d( + in_channels=bn_size * growth_rate, + out_channels=growth_rate, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + ] + if dropout_rate: + self.dense_layer.append(nn.Dropout(p=dropout_rate)) + + self.dense_layer = nn.Sequential(*self.dense_layer) + + def forward(self, x: Union[Tensor, List[Tensor]]) -> Tensor: + if isinstance(x, list): + x = torch.cat(x, 1) + return self.dense_layer(x) + + +class _DenseBlock(nn.Module): + def __init__( + self, + num_layers: int, + in_channels: int, + bn_size: int, + growth_rate: int, + dropout_rate: float, + activation: str = "relu", + ) -> None: + super().__init__() + self.dense_block = self._build_dense_blocks( + num_layers, in_channels, bn_size, growth_rate, dropout_rate, activation, + ) + + def _build_dense_blocks( + self, + num_layers: int, + in_channels: int, + bn_size: int, + growth_rate: int, + dropout_rate: float, + activation: str = "relu", + ) -> nn.ModuleList: + dense_block = [] + for i in range(num_layers): + dense_block.append( + _DenseLayer( + in_channels=in_channels + i * growth_rate, + growth_rate=growth_rate, + bn_size=bn_size, + dropout_rate=dropout_rate, + activation=activation, + ) + ) + return nn.ModuleList(dense_block) + + def forward(self, x: Tensor) -> Tensor: + feature_maps = [x] + for layer in self.dense_block: + x = layer(feature_maps) + feature_maps.append(x) + return torch.cat(feature_maps, 1) + + +class _Transition(nn.Module): + def __init__( + self, in_channels: int, out_channels: int, activation: str = "relu", + ) -> None: + super().__init__() + activation_fn = activation_function(activation) + self.transition = nn.Sequential( + nn.BatchNorm2d(in_channels), + activation_fn, + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + bias=False, + ), + nn.AvgPool2d(kernel_size=2, stride=2), + ) + + def forward(self, x: Tensor) -> Tensor: + return self.transition(x) + + +class DenseNet(nn.Module): + """Implementation of Densenet, a network archtecture that concats previous layers for maximum infomation flow.""" + + def __init__( + self, + growth_rate: int = 32, + block_config: List[int] = (6, 12, 24, 16), + in_channels: int = 1, + base_channels: int = 64, + num_classes: int = 80, + bn_size: int = 4, + dropout_rate: float = 0, + classifier: bool = True, + activation: str = "relu", + ) -> None: + super().__init__() + self.densenet = self._configure_densenet( + in_channels, + base_channels, + num_classes, + growth_rate, + block_config, + bn_size, + dropout_rate, + classifier, + activation, + ) + + def _configure_densenet( + self, + in_channels: int, + base_channels: int, + num_classes: int, + growth_rate: int, + block_config: List[int], + bn_size: int, + dropout_rate: float, + classifier: bool, + activation: str, + ) -> nn.Sequential: + activation_fn = activation_function(activation) + densenet = [ + nn.Conv2d( + in_channels=in_channels, + out_channels=base_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + nn.BatchNorm2d(base_channels), + activation_fn, + ] + + num_features = base_channels + + for i, num_layers in enumerate(block_config): + densenet.append( + _DenseBlock( + num_layers=num_layers, + in_channels=num_features, + bn_size=bn_size, + growth_rate=growth_rate, + dropout_rate=dropout_rate, + activation=activation, + ) + ) + num_features = num_features + num_layers * growth_rate + if i != len(block_config) - 1: + densenet.append( + _Transition( + in_channels=num_features, + out_channels=num_features // 2, + activation=activation, + ) + ) + num_features = num_features // 2 + + densenet.append(activation_fn) + + if classifier: + densenet.append(nn.AdaptiveAvgPool2d((1, 1))) + densenet.append(Rearrange("b c h w -> b (c h w)")) + densenet.append( + nn.Linear(in_features=num_features, out_features=num_classes) + ) + + return nn.Sequential(*densenet) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass of Densenet.""" + # If batch dimenstion is missing, it will be added. + if len(x.shape) < 4: + x = x[(None,) * (4 - len(x.shape))] + return self.densenet(x) diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py index 53c575e..527e1a0 100644 --- a/src/text_recognizer/networks/lenet.py +++ b/src/text_recognizer/networks/lenet.py @@ -5,7 +5,7 @@ from einops.layers.torch import Rearrange import torch from torch import nn -from text_recognizer.networks.misc import activation_function +from text_recognizer.networks.util import activation_function class LeNet(nn.Module): @@ -63,6 +63,6 @@ class LeNet(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: """The feedforward pass.""" # If batch dimenstion is missing, it needs to be added. - if len(x.shape) == 3: - x = x.unsqueeze(0) + if len(x.shape) < 4: + x = x[(None,) * (4 - len(x.shape))] return self.layers(x) diff --git a/src/text_recognizer/networks/line_lstm_ctc.py b/src/text_recognizer/networks/line_lstm_ctc.py deleted file mode 100644 index 9009f94..0000000 --- a/src/text_recognizer/networks/line_lstm_ctc.py +++ /dev/null @@ -1,120 +0,0 @@ -"""LSTM with CTC for handwritten text recognition within a line.""" -import importlib -from pathlib import Path -from typing import Callable, Dict, List, Optional, Tuple, Type, Union - -from einops import rearrange, reduce -from einops.layers.torch import Rearrange, Reduce -from loguru import logger -import torch -from torch import nn -from torch import Tensor - - -class LineRecurrentNetwork(nn.Module): - """Network that takes a image of a text line and predicts tokens that are in the image.""" - - def __init__( - self, - backbone: str, - backbone_args: Dict = None, - flatten: bool = True, - input_size: int = 128, - hidden_size: int = 128, - bidirectional: bool = False, - num_layers: int = 1, - num_classes: int = 80, - patch_size: Tuple[int, int] = (28, 28), - stride: Tuple[int, int] = (1, 14), - ) -> None: - super().__init__() - self.backbone_args = backbone_args or {} - self.patch_size = patch_size - self.stride = stride - self.sliding_window = self._configure_sliding_window() - self.input_size = input_size - self.hidden_size = hidden_size - self.backbone = self._configure_backbone(backbone) - self.bidirectional = bidirectional - self.flatten = flatten - - if self.flatten: - self.fc = nn.Linear( - in_features=self.input_size, out_features=self.hidden_size - ) - - self.rnn = nn.LSTM( - input_size=self.hidden_size, - hidden_size=self.hidden_size, - bidirectional=bidirectional, - num_layers=num_layers, - ) - - decoder_size = self.hidden_size * 2 if self.bidirectional else self.hidden_size - - self.decoder = nn.Sequential( - nn.Linear(in_features=decoder_size, out_features=num_classes), - nn.LogSoftmax(dim=2), - ) - - def _configure_backbone(self, backbone: str) -> Type[nn.Module]: - network_module = importlib.import_module("text_recognizer.networks") - backbone_ = getattr(network_module, backbone) - - if "pretrained" in self.backbone_args: - logger.info("Loading pretrained backbone.") - checkpoint_file = Path(__file__).resolve().parents[ - 2 - ] / self.backbone_args.pop("pretrained") - - # Loading state directory. - state_dict = torch.load(checkpoint_file) - network_args = state_dict["network_args"] - weights = state_dict["model_state"] - - # Initializes the network with trained weights. - backbone = backbone_(**network_args) - backbone.load_state_dict(weights) - if "freeze" in self.backbone_args and self.backbone_args["freeze"] is True: - for params in backbone.parameters(): - params.requires_grad = False - - return backbone - else: - return backbone_(**self.backbone_args) - - def _configure_sliding_window(self) -> nn.Sequential: - return nn.Sequential( - nn.Unfold(kernel_size=self.patch_size, stride=self.stride), - Rearrange( - "b (c h w) t -> b t c h w", - h=self.patch_size[0], - w=self.patch_size[1], - c=1, - ), - ) - - def forward(self, x: Tensor) -> Tensor: - """Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM.""" - if len(x.shape) == 3: - x = x.unsqueeze(0) - x = self.sliding_window(x) - - # Rearrange from a sequence of patches for feedforward network. - b, t = x.shape[:2] - x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t) - x = self.backbone(x) - - # Avgerage pooling. - x = ( - self.fc(reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)) - if self.flatten - else rearrange(x, "(b t) h -> t b h", b=b, t=t) - ) - - # Sequence predictions. - x, _ = self.rnn(x) - - # Sequence to classifcation layer. - x = self.decoder(x) - return x diff --git a/src/text_recognizer/networks/loss.py b/src/text_recognizer/networks/loss.py new file mode 100644 index 0000000..cf9fa0d --- /dev/null +++ b/src/text_recognizer/networks/loss.py @@ -0,0 +1,69 @@ +"""Implementations of custom loss functions.""" +from pytorch_metric_learning import distances, losses, miners, reducers +import torch +from torch import nn +from torch import Tensor +from torch.autograd import Variable +import torch.nn.functional as F + +__all__ = ["EmbeddingLoss", "LabelSmoothingCrossEntropy"] + + +class EmbeddingLoss: + """Metric loss for training encoders to produce information-rich latent embeddings.""" + + def __init__(self, margin: float = 0.2, type_of_triplets: str = "semihard") -> None: + self.distance = distances.CosineSimilarity() + self.reducer = reducers.ThresholdReducer(low=0) + self.loss_fn = losses.TripletMarginLoss( + margin=margin, distance=self.distance, reducer=self.reducer + ) + self.miner = miners.MultiSimilarityMiner(epsilon=margin, distance=self.distance) + + def __call__(self, embeddings: Tensor, labels: Tensor) -> Tensor: + """Computes the metric loss for the embeddings based on their labels. + + Args: + embeddings (Tensor): The laten vectors encoded by the network. + labels (Tensor): Labels of the embeddings. + + Returns: + Tensor: The metric loss for the embeddings. + + """ + hard_pairs = self.miner(embeddings, labels) + loss = self.loss_fn(embeddings, labels, hard_pairs) + return loss + + +class LabelSmoothingCrossEntropy(nn.Module): + """Label smoothing loss function.""" + + def __init__( + self, + classes: int, + smoothing: float = 0.0, + ignore_index: int = None, + dim: int = -1, + ) -> None: + super().__init__() + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + self.ignore_index = ignore_index + self.cls = classes + self.dim = dim + + def forward(self, pred: Tensor, target: Tensor) -> Tensor: + """Calculates the loss.""" + pred = pred.log_softmax(dim=self.dim) + with torch.no_grad(): + # true_dist = pred.data.clone() + true_dist = torch.zeros_like(pred) + true_dist.fill_(self.smoothing / (self.cls - 1)) + true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) + if self.ignore_index is not None: + true_dist[:, self.ignore_index] = 0 + mask = torch.nonzero(target == self.ignore_index, as_tuple=False) + if mask.dim() > 0: + true_dist.index_fill_(0, mask.squeeze(), 0.0) + return torch.mean(torch.sum(-true_dist * pred, dim=self.dim)) diff --git a/src/text_recognizer/networks/losses.py b/src/text_recognizer/networks/losses.py deleted file mode 100644 index 73e0641..0000000 --- a/src/text_recognizer/networks/losses.py +++ /dev/null @@ -1,31 +0,0 @@ -"""Implementations of custom loss functions.""" -from pytorch_metric_learning import distances, losses, miners, reducers -from torch import nn -from torch import Tensor - - -class EmbeddingLoss: - """Metric loss for training encoders to produce information-rich latent embeddings.""" - - def __init__(self, margin: float = 0.2, type_of_triplets: str = "semihard") -> None: - self.distance = distances.CosineSimilarity() - self.reducer = reducers.ThresholdReducer(low=0) - self.loss_fn = losses.TripletMarginLoss( - margin=margin, distance=self.distance, reducer=self.reducer - ) - self.miner = miners.MultiSimilarityMiner(epsilon=margin, distance=self.distance) - - def __call__(self, embeddings: Tensor, labels: Tensor) -> Tensor: - """Computes the metric loss for the embeddings based on their labels. - - Args: - embeddings (Tensor): The laten vectors encoded by the network. - labels (Tensor): Labels of the embeddings. - - Returns: - Tensor: The metric loss for the embeddings. - - """ - hard_pairs = self.miner(embeddings, labels) - loss = self.loss_fn(embeddings, labels, hard_pairs) - return loss diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/misc.py deleted file mode 100644 index 1f853e9..0000000 --- a/src/text_recognizer/networks/misc.py +++ /dev/null @@ -1,45 +0,0 @@ -"""Miscellaneous neural network functionality.""" -from typing import Tuple, Type - -from einops import rearrange -import torch -from torch import nn - - -def sliding_window( - images: torch.Tensor, patch_size: Tuple[int, int], stride: Tuple[int, int] -) -> torch.Tensor: - """Creates patches of an image. - - Args: - images (torch.Tensor): A Torch tensor of a 4D image(s), i.e. (batch, channel, height, width). - patch_size (Tuple[int, int]): The size of the patches to generate, e.g. 28x28 for EMNIST. - stride (Tuple[int, int]): The stride of the sliding window. - - Returns: - torch.Tensor: A tensor with the shape (batch, patches, height, width). - - """ - unfold = nn.Unfold(kernel_size=patch_size, stride=stride) - # Preform the slidning window, unsqueeze as the channel dimesion is lost. - c = images.shape[1] - patches = unfold(images) - patches = rearrange( - patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1] - ) - return patches - - -def activation_function(activation: str) -> Type[nn.Module]: - """Returns the callable activation function.""" - activation_fns = nn.ModuleDict( - [ - ["elu", nn.ELU(inplace=True)], - ["gelu", nn.GELU()], - ["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)], - ["none", nn.Identity()], - ["relu", nn.ReLU(inplace=True)], - ["selu", nn.SELU(inplace=True)], - ] - ) - return activation_fns[activation.lower()] diff --git a/src/text_recognizer/networks/mlp.py b/src/text_recognizer/networks/mlp.py index d66af28..1101912 100644 --- a/src/text_recognizer/networks/mlp.py +++ b/src/text_recognizer/networks/mlp.py @@ -5,7 +5,7 @@ from einops.layers.torch import Rearrange import torch from torch import nn -from text_recognizer.networks.misc import activation_function +from text_recognizer.networks.util import activation_function class MLP(nn.Module): @@ -63,8 +63,8 @@ class MLP(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: """The feedforward pass.""" # If batch dimenstion is missing, it needs to be added. - if len(x.shape) == 3: - x = x.unsqueeze(0) + if len(x.shape) < 4: + x = x[(None,) * (4 - len(x.shape))] return self.layers(x) @property diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py index 046600d..6405192 100644 --- a/src/text_recognizer/networks/residual_network.py +++ b/src/text_recognizer/networks/residual_network.py @@ -7,8 +7,8 @@ import torch from torch import nn from torch import Tensor -from text_recognizer.networks.misc import activation_function from text_recognizer.networks.stn import SpatialTransformerNetwork +from text_recognizer.networks.util import activation_function class Conv2dAuto(nn.Conv2d): @@ -225,8 +225,8 @@ class ResidualNetworkEncoder(nn.Module): in_channels=in_channels, out_channels=self.block_sizes[0], kernel_size=3, - stride=2, - padding=3, + stride=1, + padding=1, bias=False, ), nn.BatchNorm2d(self.block_sizes[0]), diff --git a/src/text_recognizer/networks/sparse_mlp.py b/src/text_recognizer/networks/sparse_mlp.py new file mode 100644 index 0000000..53cf166 --- /dev/null +++ b/src/text_recognizer/networks/sparse_mlp.py @@ -0,0 +1,78 @@ +"""Defines the Sparse MLP network.""" +from typing import Callable, Dict, List, Optional, Union +import warnings + +from einops.layers.torch import Rearrange +from pytorch_block_sparse import BlockSparseLinear +import torch +from torch import nn + +from text_recognizer.networks.util import activation_function + +warnings.filterwarnings("ignore", category=DeprecationWarning) + + +class SparseMLP(nn.Module): + """Sparse multi layered perceptron network.""" + + def __init__( + self, + input_size: int = 784, + num_classes: int = 10, + hidden_size: Union[int, List] = 128, + num_layers: int = 3, + density: float = 0.1, + activation_fn: str = "relu", + ) -> None: + """Initialization of the MLP network. + + Args: + input_size (int): The input shape of the network. Defaults to 784. + num_classes (int): Number of classes in the dataset. Defaults to 10. + hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128. + num_layers (int): The number of hidden layers. Defaults to 3. + density (float): The density of activation at each layer. Default to 0.1. + activation_fn (str): Name of the activation function in the hidden layers. Defaults to + relu. + + """ + super().__init__() + + activation_fn = activation_function(activation_fn) + + if isinstance(hidden_size, int): + hidden_size = [hidden_size] * num_layers + + self.layers = [ + Rearrange("b c h w -> b (c h w)"), + nn.Linear(in_features=input_size, out_features=hidden_size[0]), + activation_fn, + ] + + for i in range(num_layers - 1): + self.layers += [ + BlockSparseLinear( + in_features=hidden_size[i], + out_features=hidden_size[i + 1], + density=density, + ), + activation_fn, + ] + + self.layers.append( + nn.Linear(in_features=hidden_size[-1], out_features=num_classes) + ) + + self.layers = nn.Sequential(*self.layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """The feedforward pass.""" + # If batch dimenstion is missing, it needs to be added. + if len(x.shape) < 4: + x = x[(None,) * (4 - len(x.shape))] + return self.layers(x) + + @property + def __name__(self) -> str: + """Returns the name of the network.""" + return "mlp" diff --git a/src/text_recognizer/networks/transformer.py b/src/text_recognizer/networks/transformer.py deleted file mode 100644 index c091ba0..0000000 --- a/src/text_recognizer/networks/transformer.py +++ /dev/null @@ -1,5 +0,0 @@ -"""TBC.""" -from typing import Dict - -import torch -from torch import Tensor diff --git a/src/text_recognizer/networks/transformer/__init__.py b/src/text_recognizer/networks/transformer/__init__.py new file mode 100644 index 0000000..020a917 --- /dev/null +++ b/src/text_recognizer/networks/transformer/__init__.py @@ -0,0 +1,3 @@ +"""Transformer modules.""" +from .positional_encoding import PositionalEncoding +from .transformer import Decoder, Encoder, Transformer diff --git a/src/text_recognizer/networks/transformer/attention.py b/src/text_recognizer/networks/transformer/attention.py new file mode 100644 index 0000000..cce1ecc --- /dev/null +++ b/src/text_recognizer/networks/transformer/attention.py @@ -0,0 +1,93 @@ +"""Implementes the attention module for the transformer.""" +from typing import Optional, Tuple + +from einops import rearrange +import numpy as np +import torch +from torch import nn +from torch import Tensor + + +class MultiHeadAttention(nn.Module): + """Implementation of multihead attention.""" + + def __init__( + self, hidden_dim: int, num_heads: int = 8, dropout_rate: float = 0.0 + ) -> None: + super().__init__() + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.fc_q = nn.Linear( + in_features=hidden_dim, out_features=hidden_dim, bias=False + ) + self.fc_k = nn.Linear( + in_features=hidden_dim, out_features=hidden_dim, bias=False + ) + self.fc_v = nn.Linear( + in_features=hidden_dim, out_features=hidden_dim, bias=False + ) + self.fc_out = nn.Linear(in_features=hidden_dim, out_features=hidden_dim) + + self._init_weights() + + self.dropout = nn.Dropout(p=dropout_rate) + + def _init_weights(self) -> None: + nn.init.normal_( + self.fc_q.weight, + mean=0, + std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)), + ) + nn.init.normal_( + self.fc_k.weight, + mean=0, + std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)), + ) + nn.init.normal_( + self.fc_v.weight, + mean=0, + std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)), + ) + nn.init.xavier_normal_(self.fc_out.weight) + + def scaled_dot_product_attention( + self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None + ) -> Tensor: + """Calculates the scaled dot product attention.""" + + # Compute the energy. + energy = torch.einsum("bhlk,bhtk->bhlt", [query, key]) / np.sqrt( + query.shape[-1] + ) + + # If we have a mask for padding some inputs. + if mask is not None: + energy = energy.masked_fill(mask == 0, -np.inf) + + # Compute the attention from the energy. + attention = torch.softmax(energy, dim=3) + + out = torch.einsum("bhlt,bhtv->bhlv", [attention, value]) + out = rearrange(out, "b head l v -> b l (head v)") + return out, attention + + def forward( + self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None + ) -> Tuple[Tensor, Tensor]: + """Forward pass for computing the multihead attention.""" + # Get the query, key, and value tensor. + query = rearrange( + self.fc_q(query), "b l (head k) -> b head l k", head=self.num_heads + ) + key = rearrange( + self.fc_k(key), "b t (head k) -> b head t k", head=self.num_heads + ) + value = rearrange( + self.fc_v(value), "b t (head v) -> b head t v", head=self.num_heads + ) + + out, attention = self.scaled_dot_product_attention(query, key, value, mask) + + out = self.fc_out(out) + out = self.dropout(out) + return out, attention diff --git a/src/text_recognizer/networks/transformer/positional_encoding.py b/src/text_recognizer/networks/transformer/positional_encoding.py new file mode 100644 index 0000000..1ba5537 --- /dev/null +++ b/src/text_recognizer/networks/transformer/positional_encoding.py @@ -0,0 +1,32 @@ +"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence.""" +import numpy as np +import torch +from torch import nn +from torch import Tensor + + +class PositionalEncoding(nn.Module): + """Encodes a sense of distance or time for transformer networks.""" + + def __init__( + self, hidden_dim: int, dropout_rate: float, max_len: int = 1000 + ) -> None: + super().__init__() + self.dropout = nn.Dropout(p=dropout_rate) + self.max_len = max_len + + pe = torch.zeros(max_len, hidden_dim) + position = torch.arange(0, max_len).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim) + ) + + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.register_buffer("pe", pe) + + def forward(self, x: Tensor) -> Tensor: + """Encodes the tensor with a postional embedding.""" + x = x + self.pe[:, : x.shape[1]] + return self.dropout(x) diff --git a/src/text_recognizer/networks/transformer/transformer.py b/src/text_recognizer/networks/transformer/transformer.py new file mode 100644 index 0000000..c6e943e --- /dev/null +++ b/src/text_recognizer/networks/transformer/transformer.py @@ -0,0 +1,242 @@ +"""Transfomer module.""" +import copy +from typing import Dict, Optional, Type, Union + +import numpy as np +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.transformer.attention import MultiHeadAttention +from text_recognizer.networks.util import activation_function + + +def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList: + return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)]) + + +class _IntraLayerConnection(nn.Module): + """Preforms the residual connection inside the transfomer blocks and applies layernorm.""" + + def __init__(self, dropout_rate: float, hidden_dim: int) -> None: + super().__init__() + self.norm = nn.LayerNorm(normalized_shape=hidden_dim) + self.dropout = nn.Dropout(p=dropout_rate) + + def forward(self, src: Tensor, residual: Tensor) -> Tensor: + return self.norm(self.dropout(src) + residual) + + +class _ConvolutionalLayer(nn.Module): + def __init__( + self, + hidden_dim: int, + expansion_dim: int, + dropout_rate: float, + activation: str = "relu", + ) -> None: + super().__init__() + self.layer = nn.Sequential( + nn.Linear(in_features=hidden_dim, out_features=expansion_dim), + activation_function(activation), + nn.Dropout(p=dropout_rate), + nn.Linear(in_features=expansion_dim, out_features=hidden_dim), + ) + + def forward(self, x: Tensor) -> Tensor: + return self.layer(x) + + +class EncoderLayer(nn.Module): + """Transfomer encoding layer.""" + + def __init__( + self, + hidden_dim: int, + num_heads: int, + expansion_dim: int, + dropout_rate: float, + activation: str = "relu", + ) -> None: + super().__init__() + self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate) + self.cnn = _ConvolutionalLayer( + hidden_dim, expansion_dim, dropout_rate, activation + ) + self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim) + self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim) + + def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor: + """Forward pass through the encoder.""" + # First block. + # Multi head attention. + out, _ = self.self_attention(src, src, src, mask) + + # Add & norm. + out = self.block1(out, src) + + # Second block. + # Apply 1D-convolution. + cnn_out = self.cnn(out) + + # Add & norm. + out = self.block2(cnn_out, out) + + return out + + +class Encoder(nn.Module): + """Transfomer encoder module.""" + + def __init__( + self, + num_layers: int, + encoder_layer: Type[nn.Module], + norm: Optional[Type[nn.Module]] = None, + ) -> None: + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.norm = norm + + def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor: + """Forward pass through all encoder layers.""" + for layer in self.layers: + src = layer(src, src_mask) + + if self.norm is not None: + src = self.norm(src) + + return src + + +class DecoderLayer(nn.Module): + """Transfomer decoder layer.""" + + def __init__( + self, + hidden_dim: int, + num_heads: int, + expansion_dim: int, + dropout_rate: float = 0.0, + activation: str = "relu", + ) -> None: + super().__init__() + self.hidden_dim = hidden_dim + self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate) + self.multihead_attention = MultiHeadAttention( + hidden_dim, num_heads, dropout_rate + ) + self.cnn = _ConvolutionalLayer( + hidden_dim, expansion_dim, dropout_rate, activation + ) + self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim) + self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim) + self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim) + + def forward( + self, + trg: Tensor, + memory: Tensor, + trg_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + ) -> Tensor: + """Forward pass of the layer.""" + out, _ = self.self_attention(trg, trg, trg, trg_mask) + trg = self.block1(out, trg) + + out, _ = self.multihead_attention(trg, memory, memory, memory_mask) + trg = self.block2(out, trg) + + out = self.cnn(trg) + out = self.block3(out, trg) + + return out + + +class Decoder(nn.Module): + """Transfomer decoder module.""" + + def __init__( + self, + decoder_layer: Type[nn.Module], + num_layers: int, + norm: Optional[Type[nn.Module]] = None, + ) -> None: + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + trg: Tensor, + memory: Tensor, + trg_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + ) -> Tensor: + """Forward pass through the decoder.""" + for layer in self.layers: + trg = layer(trg, memory, trg_mask, memory_mask) + + if self.norm is not None: + trg = self.norm(trg) + + return trg + + +class Transformer(nn.Module): + """Transformer network.""" + + def __init__( + self, + num_encoder_layers: int, + num_decoder_layers: int, + hidden_dim: int, + num_heads: int, + expansion_dim: int, + dropout_rate: float, + activation: str = "relu", + ) -> None: + super().__init__() + + # Configure encoder. + encoder_norm = nn.LayerNorm(hidden_dim) + encoder_layer = EncoderLayer( + hidden_dim, num_heads, expansion_dim, dropout_rate, activation + ) + self.encoder = Encoder(num_encoder_layers, encoder_layer, encoder_norm) + + # Configure decoder. + decoder_norm = nn.LayerNorm(hidden_dim) + decoder_layer = DecoderLayer( + hidden_dim, num_heads, expansion_dim, dropout_rate, activation + ) + self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm) + + self._reset_parameters() + + def _reset_parameters(self) -> None: + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward( + self, + src: Tensor, + trg: Tensor, + src_mask: Optional[Tensor] = None, + trg_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + ) -> Tensor: + """Forward pass through the transformer.""" + if src.shape[0] != trg.shape[0]: + print(trg.shape) + raise RuntimeError("The batch size of the src and trg must be the same.") + if src.shape[2] != trg.shape[2]: + raise RuntimeError( + "The number of features for the src and trg must be the same." + ) + + memory = self.encoder(src, src_mask) + output = self.decoder(trg, memory, trg_mask, memory_mask) + return output diff --git a/src/text_recognizer/networks/util.py b/src/text_recognizer/networks/util.py new file mode 100644 index 0000000..b31e640 --- /dev/null +++ b/src/text_recognizer/networks/util.py @@ -0,0 +1,83 @@ +"""Miscellaneous neural network functionality.""" +import importlib +from pathlib import Path +from typing import Dict, Tuple, Type + +from einops import rearrange +from loguru import logger +import torch +from torch import nn + + +def sliding_window( + images: torch.Tensor, patch_size: Tuple[int, int], stride: Tuple[int, int] +) -> torch.Tensor: + """Creates patches of an image. + + Args: + images (torch.Tensor): A Torch tensor of a 4D image(s), i.e. (batch, channel, height, width). + patch_size (Tuple[int, int]): The size of the patches to generate, e.g. 28x28 for EMNIST. + stride (Tuple[int, int]): The stride of the sliding window. + + Returns: + torch.Tensor: A tensor with the shape (batch, patches, height, width). + + """ + unfold = nn.Unfold(kernel_size=patch_size, stride=stride) + # Preform the slidning window, unsqueeze as the channel dimesion is lost. + c = images.shape[1] + patches = unfold(images) + patches = rearrange( + patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1], + ) + return patches + + +def activation_function(activation: str) -> Type[nn.Module]: + """Returns the callable activation function.""" + activation_fns = nn.ModuleDict( + [ + ["elu", nn.ELU(inplace=True)], + ["gelu", nn.GELU()], + ["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)], + ["none", nn.Identity()], + ["relu", nn.ReLU(inplace=True)], + ["selu", nn.SELU(inplace=True)], + ] + ) + return activation_fns[activation.lower()] + + +def configure_backbone(backbone: str, backbone_args: Dict) -> Type[nn.Module]: + """Loads a backbone network.""" + network_module = importlib.import_module("text_recognizer.networks") + backbone_ = getattr(network_module, backbone) + + if "pretrained" in backbone_args: + logger.info("Loading pretrained backbone.") + checkpoint_file = Path(__file__).resolve().parents[2] / backbone_args.pop( + "pretrained" + ) + + # Loading state directory. + state_dict = torch.load(checkpoint_file) + network_args = state_dict["network_args"] + weights = state_dict["model_state"] + + # Initializes the network with trained weights. + backbone = backbone_(**network_args) + backbone.load_state_dict(weights) + if "freeze" in backbone_args and backbone_args["freeze"] is True: + for params in backbone.parameters(): + params.requires_grad = False + + else: + backbone_ = getattr(network_module, backbone) + backbone = backbone_(**backbone_args) + + if "remove_layers" in backbone_args and backbone_args["remove_layers"] is not None: + backbone = nn.Sequential( + *list(backbone.children())[:][: -backbone_args["remove_layers"]] + ) + + return backbone diff --git a/src/text_recognizer/networks/vision_transformer.py b/src/text_recognizer/networks/vision_transformer.py new file mode 100644 index 0000000..f227954 --- /dev/null +++ b/src/text_recognizer/networks/vision_transformer.py @@ -0,0 +1,159 @@ +"""VisionTransformer module. + +Splits each image into patches and feeds them to a transformer. + +""" + +from typing import Dict, Optional, Tuple, Type + +from einops import rearrange, reduce +from einops.layers.torch import Rearrange +from loguru import logger +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.transformer import PositionalEncoding, Transformer +from text_recognizer.networks.util import configure_backbone + + +class VisionTransformer(nn.Module): + """Linear projection+Transfomer for image to sequence prediction, sort of based on the ideas from ViT.""" + + def __init__( + self, + num_encoder_layers: int, + num_decoder_layers: int, + hidden_dim: int, + vocab_size: int, + num_heads: int, + max_len: int, + expansion_dim: int, + dropout_rate: float, + trg_pad_index: int, + mlp_dim: Optional[int] = None, + patch_size: Tuple[int, int] = (28, 28), + stride: Tuple[int, int] = (1, 14), + activation: str = "gelu", + backbone: Optional[str] = None, + backbone_args: Optional[Dict] = None, + ) -> None: + super().__init__() + + self.patch_size = patch_size + self.stride = stride + self.trg_pad_index = trg_pad_index + self.slidning_window = self._configure_sliding_window() + self.character_embedding = nn.Embedding(vocab_size, hidden_dim) + self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate, max_len) + self.mlp_dim = mlp_dim + + self.use_backbone = False + if backbone is None: + self.linear_projection = nn.Linear( + self.patch_size[0] * self.patch_size[1], hidden_dim + ) + else: + self.backbone = configure_backbone(backbone, backbone_args) + if mlp_dim: + self.mlp = nn.Linear(mlp_dim, hidden_dim) + self.use_backbone = True + + self.transformer = Transformer( + num_encoder_layers, + num_decoder_layers, + hidden_dim, + num_heads, + expansion_dim, + dropout_rate, + activation, + ) + + self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),) + + def _configure_sliding_window(self) -> nn.Sequential: + return nn.Sequential( + nn.Unfold(kernel_size=self.patch_size, stride=self.stride), + Rearrange( + "b (c h w) t -> b t c h w", + h=self.patch_size[0], + w=self.patch_size[1], + c=1, + ), + ) + + def _create_trg_mask(self, trg: Tensor) -> Tensor: + # Move this outside the transformer. + trg_pad_mask = (trg != self.trg_pad_index)[:, None, None] + trg_len = trg.shape[1] + trg_sub_mask = torch.tril( + torch.ones((trg_len, trg_len), device=trg.device) + ).bool() + trg_mask = trg_pad_mask & trg_sub_mask + return trg_mask + + def encoder(self, src: Tensor) -> Tensor: + """Forward pass with the encoder of the transformer.""" + return self.transformer.encoder(src) + + def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor: + """Forward pass with the decoder of the transformer + classification head.""" + return self.head( + self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask) + ) + + def _backbone(self, x: Tensor) -> Tensor: + b, t = x.shape[:2] + if self.use_backbone: + x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t) + x = self.backbone(x) + if self.mlp_dim: + x = rearrange(x, "(b t) c h w -> b t (c h w)", b=b, t=t) + x = self.mlp(x) + else: + x = rearrange(x, "(b t) h -> b t h", b=b, t=t) + else: + x = rearrange(x, "b t c h w -> b t (c h w)", b=b, t=t) + x = self.linear_projection(x) + return x + + def preprocess_input(self, src: Tensor) -> Tensor: + """Encodes src with a backbone network and a positional encoding. + + Args: + src (Tensor): Input tensor. + + Returns: + Tensor: A input src to the transformer. + + """ + # If batch dimenstion is missing, it needs to be added. + if len(src.shape) < 4: + src = src[(None,) * (4 - len(src.shape))] + src = self.slidning_window(src) # .squeeze(-2) + src = self._backbone(src) + src = self.position_encoding(src) + return src + + def preprocess_target(self, trg: Tensor) -> Tuple[Tensor, Tensor]: + """Encodes target tensor with embedding and postion. + + Args: + trg (Tensor): Target tensor. + + Returns: + Tuple[Tensor, Tensor]: Encoded target tensor and target mask. + + """ + trg_mask = self._create_trg_mask(trg) + trg = self.character_embedding(trg.long()) + trg = self.position_encoding(trg) + return trg, trg_mask + + def forward(self, x: Tensor, trg: Tensor) -> Tensor: + """Forward pass with vision transfomer.""" + src = self.preprocess_input(x) + trg, trg_mask = self.preprocess_target(trg) + out = self.transformer(src, trg, trg_mask=trg_mask) + logits = self.head(out) + return logits diff --git a/src/text_recognizer/networks/wide_resnet.py b/src/text_recognizer/networks/wide_resnet.py index 618f414..aa79c12 100644 --- a/src/text_recognizer/networks/wide_resnet.py +++ b/src/text_recognizer/networks/wide_resnet.py @@ -8,7 +8,7 @@ import torch from torch import nn from torch import Tensor -from text_recognizer.networks.misc import activation_function +from text_recognizer.networks.util import activation_function def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: @@ -206,8 +206,8 @@ class WideResidualNetwork(nn.Module): def forward(self, x: Tensor) -> Tensor: """Feedforward pass.""" - if len(x.shape) == 3: - x = x.unsqueeze(0) + if len(x.shape) < 4: + x = x[(None,) * int(4 - len(x.shape))] x = self.encoder(x) if self.decoder is not None: x = self.decoder(x) diff --git a/src/text_recognizer/tests/support/create_emnist_support_files.py b/src/text_recognizer/tests/support/create_emnist_support_files.py index 5dd1a81..c04860d 100644 --- a/src/text_recognizer/tests/support/create_emnist_support_files.py +++ b/src/text_recognizer/tests/support/create_emnist_support_files.py @@ -2,10 +2,8 @@ from pathlib import Path import shutil -from text_recognizer.datasets.emnist_dataset import ( - fetch_emnist_dataset, - load_emnist_mapping, -) +from text_recognizer.datasets.emnist_dataset import EmnistDataset +from text_recognizer.datasets.util import EmnistMapper from text_recognizer.util import write_image SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "emnist" @@ -16,15 +14,16 @@ def create_emnist_support_files() -> None: shutil.rmtree(SUPPORT_DIRNAME, ignore_errors=True) SUPPORT_DIRNAME.mkdir() - dataset = fetch_emnist_dataset(split="byclass", train=False) - mapping = load_emnist_mapping() + dataset = EmnistDataset(train=False) + dataset.load_or_generate_data() + mapping = EmnistMapper() for index in [5, 7, 9]: image, label = dataset[index] if len(image.shape) == 3: image = image.squeeze(0) image = image.numpy() - label = mapping[int(label)] + label = mapping(int(label)) print(index, label) write_image(image, str(SUPPORT_DIRNAME / f"{label}.png")) diff --git a/src/text_recognizer/tests/test_line_predictor.py b/src/text_recognizer/tests/test_line_predictor.py new file mode 100644 index 0000000..eede4d4 --- /dev/null +++ b/src/text_recognizer/tests/test_line_predictor.py @@ -0,0 +1,35 @@ +"""Tests for LinePredictor.""" +import os +from pathlib import Path +import unittest + + +import editdistance +import numpy as np + +from text_recognizer.datasets import IamLinesDataset +from text_recognizer.line_predictor import LinePredictor +import text_recognizer.util as util + +SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "support" + +os.environ["CUDA_VISIBLE_DEVICES"] = "" + + +class TestEmnistLinePredictor(unittest.TestCase): + """Test LinePredictor class on the EmnistLines dataset.""" + + def test_filename(self) -> None: + """Test that LinePredictor correctly predicts on single images, for several test images.""" + predictor = LinePredictor( + dataset="EmnistLineDataset", network_fn="CNNTransformer" + ) + + for filename in (SUPPORT_DIRNAME / "emnist_lines").glob("*.png"): + pred, conf = predictor.predict(str(filename)) + true = str(filename.stem) + edit_distance = editdistance.eval(pred, true) / len(pred) + print( + f'Pred: "{pred}" | Confidence: {conf} | True: {true} | Edit distance: {edit_distance}' + ) + self.assertLess(edit_distance, 0.2) diff --git a/src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt b/src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt new file mode 100644 index 0000000..726c723 Binary files /dev/null and b/src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt differ diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt new file mode 100644 index 0000000..6a9a915 Binary files /dev/null and b/src/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt differ diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt deleted file mode 100644 index 676eb44..0000000 Binary files a/src/text_recognizer/weights/CharacterModel_EmnistDataset_LeNet_weights.pt and /dev/null differ diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt deleted file mode 100644 index 32c83cc..0000000 Binary files a/src/text_recognizer/weights/CharacterModel_EmnistDataset_MLP_weights.pt and /dev/null differ diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetworkEncoder_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetworkEncoder_weights.pt deleted file mode 100644 index 9f9deee..0000000 Binary files a/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetworkEncoder_weights.pt and /dev/null differ diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt deleted file mode 100644 index 0dc7eb5..0000000 Binary files a/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt and /dev/null differ diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_SpinalVGG_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_SpinalVGG_weights.pt deleted file mode 100644 index e720299..0000000 Binary files a/src/text_recognizer/weights/CharacterModel_EmnistDataset_SpinalVGG_weights.pt and /dev/null differ diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt new file mode 100644 index 0000000..2d5a89b Binary files /dev/null and b/src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt differ diff --git a/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt b/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt deleted file mode 100644 index ed73c09..0000000 Binary files a/src/text_recognizer/weights/CharacterModel_Emnist_LeNet_weights.pt and /dev/null differ diff --git a/src/text_recognizer/weights/CharacterModel_Emnist_MLP_weights.pt b/src/text_recognizer/weights/CharacterModel_Emnist_MLP_weights.pt deleted file mode 100644 index 4ec12c1..0000000 Binary files a/src/text_recognizer/weights/CharacterModel_Emnist_MLP_weights.pt and /dev/null differ diff --git a/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt b/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt index 93d34d7..7fe1fa3 100644 Binary files a/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt and b/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt differ -- cgit v1.2.3-70-g09d2