diff options
Diffstat (limited to 'text_recognizer')
78 files changed, 7439 insertions, 0 deletions
diff --git a/text_recognizer/__init__.py b/text_recognizer/__init__.py new file mode 100644 index 0000000..3dc1f76 --- /dev/null +++ b/text_recognizer/__init__.py @@ -0,0 +1 @@ +__version__ = "0.1.0" diff --git a/text_recognizer/character_predictor.py b/text_recognizer/character_predictor.py new file mode 100644 index 0000000..ad71289 --- /dev/null +++ b/text_recognizer/character_predictor.py @@ -0,0 +1,29 @@ +"""CharacterPredictor class.""" +from typing import Dict, Tuple, Type, Union + +import numpy as np +from torch import nn + +from text_recognizer import datasets, networks +from text_recognizer.models import CharacterModel +from text_recognizer.util import read_image + + +class CharacterPredictor: + """Recognizes the character in handwritten character images.""" + + def __init__(self, network_fn: str, dataset: str) -> None: + """Initializes the CharacterModel and loads the pretrained weights.""" + network_fn = getattr(networks, network_fn) + dataset = getattr(datasets, dataset) + self.model = CharacterModel(network_fn=network_fn, dataset=dataset) + self.model.eval() + self.model.use_swa_model() + + def predict(self, image_or_filename: Union[np.ndarray, str]) -> Tuple[str, float]: + """Predict on a single image containing a handwritten character.""" + if isinstance(image_or_filename, str): + image = read_image(image_or_filename, grayscale=True) + else: + image = image_or_filename + return self.model.predict_on_image(image) diff --git a/text_recognizer/datasets/__init__.py b/text_recognizer/datasets/__init__.py new file mode 100644 index 0000000..a6c1c59 --- /dev/null +++ b/text_recognizer/datasets/__init__.py @@ -0,0 +1,39 @@ +"""Dataset modules.""" +from .emnist_dataset import EmnistDataset +from .emnist_lines_dataset import ( + construct_image_from_string, + EmnistLinesDataset, + get_samples_by_character, +) +from .iam_dataset import IamDataset +from .iam_lines_dataset import IamLinesDataset +from .iam_paragraphs_dataset import IamParagraphsDataset +from .iam_preprocessor import load_metadata, Preprocessor +from .transforms import AddTokens, Transpose +from .util import ( + _download_raw_dataset, + compute_sha256, + DATA_DIRNAME, + download_url, + EmnistMapper, + ESSENTIALS_FILENAME, +) + +__all__ = [ + "_download_raw_dataset", + "AddTokens", + "compute_sha256", + "construct_image_from_string", + "DATA_DIRNAME", + "download_url", + "EmnistDataset", + "EmnistMapper", + "EmnistLinesDataset", + "get_samples_by_character", + "load_metadata", + "IamDataset", + "IamLinesDataset", + "IamParagraphsDataset", + "Preprocessor", + "Transpose", +] diff --git a/text_recognizer/datasets/dataset.py b/text_recognizer/datasets/dataset.py new file mode 100644 index 0000000..e794605 --- /dev/null +++ b/text_recognizer/datasets/dataset.py @@ -0,0 +1,152 @@ +"""Abstract dataset class.""" +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +from torch import Tensor +from torch.utils import data +from torchvision.transforms import ToTensor + +import text_recognizer.datasets.transforms as transforms +from text_recognizer.datasets.util import EmnistMapper + + +class Dataset(data.Dataset): + """Abstract class with methods common to all datasets.""" + + def __init__( + self, + train: bool, + subsample_fraction: float = None, + transform: Optional[List[Dict]] = None, + target_transform: Optional[List[Dict]] = None, + init_token: Optional[str] = None, + pad_token: Optional[str] = None, + eos_token: Optional[str] = None, + lower: bool = False, + ) -> None: + """Initialization of the
Dataset class. + + Args: + train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False. + subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None. + transform (Optional[List[Dict]]): List of Transform types and args for input data. Defaults to None. + target_transform (Optional[List[Dict]]): List of Transform types and args for output data. Defaults to None. + init_token (Optional[str]): String representing the start of sequence token. Defaults to None. + pad_token (Optional[str]): String representing the pad token. Defaults to None. + eos_token (Optional[str]): String representing the end of sequence token. Defaults to None. + lower (bool): Only use lower case letters. Defaults to False. + + Raises: + ValueError: If subsample_fraction is not None and outside the range (0, 1). + + """ + self.train = train + self.split = "train" if self.train else "test" + + if subsample_fraction is not None: + if not 0.0 < subsample_fraction < 1.0: + raise ValueError("The subsample fraction must be in (0, 1).") + self.subsample_fraction = subsample_fraction + + self._mapper = EmnistMapper( + init_token=init_token, eos_token=eos_token, pad_token=pad_token, lower=lower + ) + self._input_shape = self._mapper.input_shape + self._output_shape = self._mapper._num_classes + self.num_classes = self.mapper.num_classes + + # Set transforms. + self.transform = self._configure_transform(transform) + self.target_transform = self._configure_target_transform(target_transform) + + self._data = None + self._targets = None + + def _configure_transform(self, transform: List[Dict]) -> transforms.Compose: + transform_list = [] + if transform is not None: + for t in transform: + t_type = t["type"] + t_args = t["args"] or {} + transform_list.append(getattr(transforms, t_type)(**t_args)) + else: + transform_list.append(ToTensor()) + return transforms.Compose(transform_list) + + def _configure_target_transform( + self, target_transform: List[Dict] + ) -> transforms.Compose: + target_transform_list = [torch.tensor] + if target_transform is not None: + for t in target_transform: + t_type = t["type"] + t_args = t["args"] or {} + target_transform_list.append(getattr(transforms, t_type)(**t_args)) + return transforms.Compose(target_transform_list) + + @property + def data(self) -> Tensor: + """The input data.""" + return self._data + + @property + def targets(self) -> Tensor: + """The target data.""" + return self._targets + + @property + def input_shape(self) -> Tuple: + """Input shape of the data.""" + return self._input_shape + + @property + def output_shape(self) -> Tuple: + """Output shape of the data.""" + return self._output_shape + + @property + def mapper(self) -> EmnistMapper: + """Returns the EmnistMapper.""" + return self._mapper + + @property + def mapping(self) -> Dict: + """Return EMNIST mapping from index to character.""" + return self._mapper.mapping + + @property + def inverse_mapping(self) -> Dict: + """Returns the inverse mapping from character to index.""" + return self.mapper.inverse_mapping + + def _subsample(self) -> None: + """Only this fraction of the data will be loaded.""" + if self.subsample_fraction is None: + return + num_subsample = int(self.data.shape[0] * self.subsample_fraction) + self._data = self.data[:num_subsample] + self._targets = self.targets[:num_subsample] + + def __len__(self) -> int: + """Returns the length of the dataset.""" + return len(self.data) + + def load_or_generate_data(self) -> None: + """Load or 
generate dataset data.""" + raise NotImplementedError + + def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: + """Fetches samples from the dataset. + + Args: + index (Union[int, torch.Tensor]): The indices of the samples to fetch. + + Raises: + NotImplementedError: If the method is not implemented in child class. + + """ + raise NotImplementedError + + def __repr__(self) -> str: + """Returns information about the dataset.""" + raise NotImplementedError diff --git a/text_recognizer/datasets/emnist_dataset.py b/text_recognizer/datasets/emnist_dataset.py new file mode 100644 index 0000000..9884fdf --- /dev/null +++ b/text_recognizer/datasets/emnist_dataset.py @@ -0,0 +1,131 @@ +"""Emnist dataset: black and white images of handwritten characters (Aa-Zz) and digits (0-9).""" + +import json +from pathlib import Path +from typing import Callable, Optional, Tuple, Union + +from loguru import logger +import numpy as np +from PIL import Image +import torch +from torch import Tensor +from torchvision.datasets import EMNIST +from torchvision.transforms import Compose, ToTensor + +from text_recognizer.datasets.dataset import Dataset +from text_recognizer.datasets.transforms import Transpose +from text_recognizer.datasets.util import DATA_DIRNAME + + +class EmnistDataset(Dataset): + """This is a class for resampling and subsampling the PyTorch EMNIST dataset.""" + + def __init__( + self, + pad_token: str = None, + train: bool = False, + sample_to_balance: bool = False, + subsample_fraction: float = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + seed: int = 4711, + ) -> None: + """Loads the dataset and the mappings. + + Args: + pad_token (str): The pad token symbol. Defaults to _. + train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False. + sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False. + subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None. + transform (Optional[Callable]): Transform(s) for input data. Defaults to None. + target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None. + seed (int): Seed number. Defaults to 4711. + + """ + super().__init__( + train=train, + subsample_fraction=subsample_fraction, + transform=transform, + target_transform=target_transform, + pad_token=pad_token, + ) + + self.sample_to_balance = sample_to_balance + + # Have to transpose the emnist characters, ToTensor norms input between [0,1]. + if transform is None: + self.transform = Compose([Transpose(), ToTensor()]) + + self.target_transform = None + + self.seed = seed + + def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: + """Fetches samples from the dataset. + + Args: + index (Union[int, Tensor]): The indices of the samples to fetch. + + Returns: + Tuple[Tensor, Tensor]: Data target tuple. 
+ + """ + if torch.is_tensor(index): + index = index.tolist() + + data = self.data[index] + targets = self.targets[index] + + if self.transform: + data = self.transform(data) + + if self.target_transform: + targets = self.target_transform(targets) + + return data, targets + + def __repr__(self) -> str: + """Returns information about the dataset.""" + return ( + "EMNIST Dataset\n" + f"Num classes: {self.num_classes}\n" + f"Input shape: {self.input_shape}\n" + f"Mapping: {self.mapper.mapping}\n" + ) + + def _sample_to_balance(self) -> None: + """Because the dataset is not balanced, we take at most the mean number of instances per class.""" + np.random.seed(self.seed) + x = self._data + y = self._targets + num_to_sample = int(np.bincount(y.flatten()).mean()) + all_sampled_indices = [] + for label in np.unique(y.flatten()): + inds = np.where(y == label)[0] + sampled_indices = np.unique(np.random.choice(inds, num_to_sample)) + all_sampled_indices.append(sampled_indices) + indices = np.concatenate(all_sampled_indices) + x_sampled = x[indices] + y_sampled = y[indices] + self._data = x_sampled + self._targets = y_sampled + + def load_or_generate_data(self) -> None: + """Fetch the EMNIST dataset.""" + dataset = EMNIST( + root=DATA_DIRNAME, + split="byclass", + train=self.train, + download=False, + transform=None, + target_transform=None, + ) + + self._data = dataset.data + self._targets = dataset.targets + + if self.sample_to_balance: + self._sample_to_balance() + + if self.subsample_fraction is not None: + self._subsample() diff --git a/text_recognizer/datasets/emnist_essentials.json b/text_recognizer/datasets/emnist_essentials.json new file mode 100644 index 0000000..2a0648a --- /dev/null +++ b/text_recognizer/datasets/emnist_essentials.json @@ -0,0 +1 @@ +{"mapping": [[0, "0"], [1, "1"], [2, "2"], [3, "3"], [4, "4"], [5, "5"], [6, "6"], [7, "7"], [8, "8"], [9, "9"], [10, "A"], [11, "B"], [12, "C"], [13, "D"], [14, "E"], [15, "F"], [16, "G"], [17, "H"], [18, "I"], [19, "J"], [20, "K"], [21, "L"], [22, "M"], [23, "N"], [24, "O"], [25, "P"], [26, "Q"], [27, "R"], [28, "S"], [29, "T"], [30, "U"], [31, "V"], [32, "W"], [33, "X"], [34, "Y"], [35, "Z"], [36, "a"], [37, "b"], [38, "c"], [39, "d"], [40, "e"], [41, "f"], [42, "g"], [43, "h"], [44, "i"], [45, "j"], [46, "k"], [47, "l"], [48, "m"], [49, "n"], [50, "o"], [51, "p"], [52, "q"], [53, "r"], [54, "s"], [55, "t"], [56, "u"], [57, "v"], [58, "w"], [59, "x"], [60, "y"], [61, "z"]], "input_shape": [28, 28]} diff --git a/text_recognizer/datasets/emnist_lines_dataset.py b/text_recognizer/datasets/emnist_lines_dataset.py new file mode 100644 index 0000000..1992446 --- /dev/null +++ b/text_recognizer/datasets/emnist_lines_dataset.py @@ -0,0 +1,359 @@ +"""Emnist Lines dataset: synthetic handwritten lines dataset made from Emnist characters.""" + +from collections import defaultdict +from pathlib import Path +from typing import Callable, Dict, List, Optional, Tuple, Union + +import click +import h5py +from loguru import logger +import numpy as np +import torch +from torch import Tensor +import torch.nn.functional as F +from torchvision.transforms import ToTensor + +from text_recognizer.datasets.dataset import Dataset +from text_recognizer.datasets.emnist_dataset import EmnistDataset, Transpose +from text_recognizer.datasets.sentence_generator import SentenceGenerator +from text_recognizer.datasets.util import ( + DATA_DIRNAME, + EmnistMapper, + ESSENTIALS_FILENAME, +) + +DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines" + +MAX_WIDTH = 952 + + 
+class EmnistLinesDataset(Dataset): + """Synthetic dataset of lines from the Brown corpus with Emnist characters.""" + + def __init__( + self, + train: bool = False, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + subsample_fraction: float = None, + max_length: int = 34, + min_overlap: float = 0, + max_overlap: float = 0.33, + num_samples: int = 10000, + seed: int = 4711, + init_token: Optional[str] = None, + pad_token: Optional[str] = None, + eos_token: Optional[str] = None, + lower: bool = False, + ) -> None: + """Sets attributes and loads the dataset. + + Args: + train (bool): Flag for the filename. Defaults to False. + transform (Optional[Callable]): The transform of the data. Defaults to None. + target_transform (Optional[Callable]): The transform of the target. Defaults to None. + subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None. + max_length (int): The maximum number of characters. Defaults to 34. + min_overlap (float): The minimum overlap between concatenated images. Defaults to 0. + max_overlap (float): The maximum overlap between concatenated images. Defaults to 0.33. + num_samples (int): Number of samples to generate. Defaults to 10000. + seed (int): Seed number. Defaults to 4711. + init_token (Optional[str]): String representing the start of sequence token. Defaults to None. + pad_token (Optional[str]): String representing the pad token. Defaults to None. + eos_token (Optional[str]): String representing the end of sequence token. Defaults to None. + lower (bool): If True, convert uppercase letters to lowercase. Otherwise, use both upper and lowercase. + + """ + self.pad_token = "_" if pad_token is None else pad_token + + super().__init__( + train=train, + transform=transform, + target_transform=target_transform, + subsample_fraction=subsample_fraction, + init_token=init_token, + pad_token=self.pad_token, + eos_token=eos_token, + lower=lower, + ) + + # Extract dataset information. + self._input_shape = self._mapper.input_shape + self.num_classes = self._mapper.num_classes + + self.max_length = max_length + self.min_overlap = min_overlap + self.max_overlap = max_overlap + self.num_samples = num_samples + self._input_shape = ( + self.input_shape[0], + self.input_shape[1] * self.max_length, + ) + self._output_shape = (self.max_length, self.num_classes) + self.seed = seed + + # Placeholders for the dataset. + self._data = None + self._targets = None + + def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: + """Fetches data, target pair of the dataset for a given index or indices. + + Args: + index (Union[int, Tensor]): Either a list or int of indices/index. + + Returns: + Tuple[Tensor, Tensor]: Data target pair.
+ + """ + if torch.is_tensor(index): + index = index.tolist() + + data = self.data[index] + targets = self.targets[index] + + if self.transform: + data = self.transform(data) + + if self.target_transform: + targets = self.target_transform(targets) + + return data, targets + + def __repr__(self) -> str: + """Returns information about the dataset.""" + return ( + "EMNIST Lines Dataset\n" # pylint: disable=no-member + f"Max length: {self.max_length}\n" + f"Min overlap: {self.min_overlap}\n" + f"Max overlap: {self.max_overlap}\n" + f"Num classes: {self.num_classes}\n" + f"Input shape: {self.input_shape}\n" + f"Data: {self.data.shape}\n" + f"Tagets: {self.targets.shape}\n" + ) + + @property + def data_filename(self) -> Path: + """Path to the h5 file.""" + filename = "train.pt" if self.train else "test.pt" + return DATA_DIRNAME / filename + + def load_or_generate_data(self) -> None: + """Loads the dataset, if it does not exist a new dataset is generated before loading it.""" + np.random.seed(self.seed) + + if not self.data_filename.exists(): + self._generate_data() + self._load_data() + self._subsample() + + def _load_data(self) -> None: + """Loads the dataset from the h5 file.""" + logger.debug("EmnistLinesDataset loading data from HDF5...") + with h5py.File(self.data_filename, "r") as f: + self._data = f["data"][()] + self._targets = f["targets"][()] + + def _generate_data(self) -> str: + """Generates a dataset with the Brown corpus and Emnist characters.""" + logger.debug("Generating data...") + + sentence_generator = SentenceGenerator(self.max_length) + + # Load emnist dataset. + emnist = EmnistDataset( + train=self.train, sample_to_balance=True, pad_token=self.pad_token + ) + emnist.load_or_generate_data() + + samples_by_character = get_samples_by_character( + emnist.data.numpy(), emnist.targets.numpy(), self.mapper.mapping, + ) + + DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + with h5py.File(self.data_filename, "a") as f: + data, targets = create_dataset_of_images( + self.num_samples, + samples_by_character, + sentence_generator, + self.min_overlap, + self.max_overlap, + ) + + targets = convert_strings_to_categorical_labels( + targets, emnist.inverse_mapping + ) + + f.create_dataset("data", data=data, dtype="u1", compression="lzf") + f.create_dataset("targets", data=targets, dtype="u1", compression="lzf") + + +def get_samples_by_character( + samples: np.ndarray, labels: np.ndarray, mapping: Dict +) -> defaultdict: + """Creates a dictionary with character as key and value as the list of images of that character. + + Args: + samples (np.ndarray): Dataset of images of characters. + labels (np.ndarray): The labels for each image. + mapping (Dict): The Emnist mapping dictionary. + + Returns: + defaultdict: A dictionary with characters as keys and list of images as values. + + """ + samples_by_character = defaultdict(list) + for sample, label in zip(samples, labels.flatten()): + samples_by_character[mapping[label]].append(sample) + return samples_by_character + + +def select_letter_samples_for_string( + string: str, samples_by_character: Dict +) -> List[np.ndarray]: + """Randomly selects Emnist characters to use for the senetence. + + Args: + string (str): The word or sentence. + samples_by_character (Dict): The dictionary of emnist images of each character. + + Returns: + List[np.ndarray]: A list of emnist images of the string. 
+ + """ + zero_image = np.zeros((28, 28), np.uint8) + sample_image_by_character = {} + for character in string: + if character in sample_image_by_character: + continue + samples = samples_by_character[character] + sample = samples[np.random.choice(len(samples))] if samples else zero_image + sample_image_by_character[character] = sample.reshape(28, 28).swapaxes(0, 1) + return [sample_image_by_character[character] for character in string] + + +def construct_image_from_string( + string: str, samples_by_character: Dict, min_overlap: float, max_overlap: float +) -> np.ndarray: + """Concatenates images of the characters in the string. + + The concatination is made with randomly selected overlap so that some portion of the character will overlap. + + Args: + string (str): The word or sentence. + samples_by_character (Dict): The dictionary of emnist images of each character. + min_overlap (float): Minimum amount of overlap between Emnist images. + max_overlap (float): Maximum amount of overlap between Emnist images. + + Returns: + np.ndarray: The Emnist image of the string. + + """ + overlap = np.random.uniform(min_overlap, max_overlap) + sampled_images = select_letter_samples_for_string(string, samples_by_character) + length = len(sampled_images) + height, width = sampled_images[0].shape + next_overlap_width = width - int(overlap * width) + concatenated_image = np.zeros((height, width * length), np.uint8) + x = 0 + for image in sampled_images: + concatenated_image[:, x : (x + width)] += image + x += next_overlap_width + + if concatenated_image.shape[-1] > MAX_WIDTH: + concatenated_image = Tensor(concatenated_image).unsqueeze(0) + concatenated_image = F.interpolate( + concatenated_image, size=MAX_WIDTH, mode="nearest" + ) + concatenated_image = concatenated_image.squeeze(0).numpy() + + return np.minimum(255, concatenated_image) + + +def create_dataset_of_images( + length: int, + samples_by_character: Dict, + sentence_generator: SentenceGenerator, + min_overlap: float, + max_overlap: float, +) -> Tuple[np.ndarray, List[str]]: + """Creates a dataset with images and labels from strings generated from the SentenceGenerator. + + Args: + length (int): The number of characters for each string. + samples_by_character (Dict): The dictionary of emnist images of each character. + sentence_generator (SentenceGenerator): A SentenceGenerator objest. + min_overlap (float): Minimum amount of overlap between Emnist images. + max_overlap (float): Maximum amount of overlap between Emnist images. + + Returns: + Tuple[np.ndarray, List[str]]: A list of Emnist images and a list of the strings (labels). + + Raises: + RuntimeError: If the sentence generator is not able to generate a string. + + """ + sample_label = sentence_generator.generate() + sample_image = construct_image_from_string(sample_label, samples_by_character, 0, 0) + images = np.zeros((length, sample_image.shape[0], sample_image.shape[1]), np.uint8) + labels = [] + for n in range(length): + label = None + # Try several times to generate before actually throwing an error. 
+ for _ in range(10): + try: + label = sentence_generator.generate() + break + except Exception: # pylint: disable=broad-except + pass + if label is None: + raise RuntimeError("Was not able to generate a valid string.") + images[n] = construct_image_from_string( + label, samples_by_character, min_overlap, max_overlap + ) + labels.append(label) + return images, labels + + +def convert_strings_to_categorical_labels( + labels: List[str], mapping: Dict +) -> np.ndarray: + """Translates a string of characters into a target array of class indices.""" + return np.array([[mapping[c] for c in label] for label in labels]) + + +@click.command() +@click.option( + "--max_length", type=int, default=34, help="Number of characters in a sentence." +) +@click.option( + "--min_overlap", type=float, default=0.0, help="Min overlap between characters." +) +@click.option( + "--max_overlap", type=float, default=0.33, help="Max overlap between characters." +) +@click.option("--num_train", type=int, default=10_000, help="Number of train examples.") +@click.option("--num_test", type=int, default=1_000, help="Number of test examples.") +def create_datasets( + max_length: int = 34, + min_overlap: float = 0, + max_overlap: float = 0.33, + num_train: int = 10000, + num_test: int = 1000, +) -> None: + """Creates a training and validation dataset of Emnist lines.""" + num_samples = [num_train, num_test] + for num, train in zip(num_samples, [True, False]): + emnist_lines = EmnistLinesDataset( + train=train, + max_length=max_length, + min_overlap=min_overlap, + max_overlap=max_overlap, + num_samples=num, + ) + emnist_lines.load_or_generate_data() + + +if __name__ == "__main__": + create_datasets() diff --git a/text_recognizer/datasets/iam_dataset.py b/text_recognizer/datasets/iam_dataset.py new file mode 100644 index 0000000..a8998b9 --- /dev/null +++ b/text_recognizer/datasets/iam_dataset.py @@ -0,0 +1,133 @@ +"""Class for loading the IAM dataset, which encompasses both paragraphs and lines, with associated utilities.""" +import os +from typing import Any, Dict, List +import zipfile + +from boltons.cacheutils import cachedproperty +import defusedxml.ElementTree as ET +from loguru import logger +import toml + +from text_recognizer.datasets.util import _download_raw_dataset, DATA_DIRNAME + +RAW_DATA_DIRNAME = DATA_DIRNAME / "raw" / "iam" +METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" +EXTRACTED_DATASET_DIRNAME = RAW_DATA_DIRNAME / "iamdb" +RAW_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + +DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be. +LINE_REGION_PADDING = 0 # Add this many pixels around the exact coordinates. + + +class IamDataset: + """IAM dataset. + + "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text, + which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels." + From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database + + The data split we will use is + IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text lines. + The validation set has been merged into the train set. + The train set has 7,101 lines from 326 writers. + The test set has 1,861 lines from 128 writers. + The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only.
+ + """ + + def __init__(self) -> None: + self.metadata = toml.load(METADATA_FILENAME) + + def load_or_generate_data(self) -> None: + """Downloads IAM dataset if xml files does not exist.""" + if not self.xml_filenames: + self._download_iam() + + @property + def xml_filenames(self) -> List: + """List of xml filenames.""" + return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml")) + + @property + def form_filenames(self) -> List: + """List of forms filenames.""" + return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg")) + + def _download_iam(self) -> None: + curdir = os.getcwd() + os.chdir(RAW_DATA_DIRNAME) + _download_raw_dataset(self.metadata) + _extract_raw_dataset(self.metadata) + os.chdir(curdir) + + @property + def form_filenames_by_id(self) -> Dict: + """Creates a dictionary with filenames as keys and forms as values.""" + return {filename.stem: filename for filename in self.form_filenames} + + @cachedproperty + def line_strings_by_id(self) -> Dict: + """Return a dict from name of IAM form to a list of line texts in it.""" + return { + filename.stem: _get_line_strings_from_xml_file(filename) + for filename in self.xml_filenames + } + + @cachedproperty + def line_regions_by_id(self) -> Dict: + """Return a dict from name of IAM form to a list of (x1, x2, y1, y2) coordinates of all lines in it.""" + return { + filename.stem: _get_line_regions_from_xml_file(filename) + for filename in self.xml_filenames + } + + def __repr__(self) -> str: + """Print info about dataset.""" + return "IAM Dataset\n" f"Number of forms: {len(self.xml_filenames)}\n" + + +def _extract_raw_dataset(metadata: Dict) -> None: + logger.info("Extracting IAM data.") + with zipfile.ZipFile(metadata["filename"], "r") as zip_file: + zip_file.extractall() + + +def _get_line_strings_from_xml_file(filename: str) -> List[str]: + """Get the text content of each line. Note that we replace " with ".""" + xml_root_element = ET.parse(filename).getroot() # nosec + xml_line_elements = xml_root_element.findall("handwritten-part/line") + return [el.attrib["text"].replace(""", '"') for el in xml_line_elements] + + +def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]: + """Get the line region dict for each line.""" + xml_root_element = ET.parse(filename).getroot() # nosec + xml_line_elements = xml_root_element.findall("handwritten-part/line") + return [_get_line_region_from_xml_element(el) for el in xml_line_elements] + + +def _get_line_region_from_xml_element(xml_line: Any) -> Dict[str, int]: + """Extracts coordinates for each line of text.""" + # TODO: fix input! 
+ word_elements = xml_line.findall("word/cmp") + x1s = [int(el.attrib["x"]) for el in word_elements] + y1s = [int(el.attrib["y"]) for el in word_elements] + x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements] + y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements] + return { + "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, + "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, + "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, + "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, + } + + +def main() -> None: + """Initializes the dataset and print info about the dataset.""" + dataset = IamDataset() + dataset.load_or_generate_data() + print(dataset) + + +if __name__ == "__main__": + main() diff --git a/text_recognizer/datasets/iam_lines_dataset.py b/text_recognizer/datasets/iam_lines_dataset.py new file mode 100644 index 0000000..1cb84bd --- /dev/null +++ b/text_recognizer/datasets/iam_lines_dataset.py @@ -0,0 +1,110 @@ +"""IamLinesDataset class.""" +from typing import Callable, Dict, List, Optional, Tuple, Union + +import h5py +from loguru import logger +import torch +from torch import Tensor +from torchvision.transforms import ToTensor + +from text_recognizer.datasets.dataset import Dataset +from text_recognizer.datasets.util import ( + compute_sha256, + DATA_DIRNAME, + download_url, + EmnistMapper, +) + + +PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines" +PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "iam_lines.h5" +PROCESSED_DATA_URL = ( + "https://s3-us-west-2.amazonaws.com/fsdl-public-assets/iam_lines.h5" +) + + +class IamLinesDataset(Dataset): + """IAM lines datasets for handwritten text lines.""" + + def __init__( + self, + train: bool = False, + subsample_fraction: float = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + init_token: Optional[str] = None, + pad_token: Optional[str] = None, + eos_token: Optional[str] = None, + lower: bool = False, + ) -> None: + self.pad_token = "_" if pad_token is None else pad_token + + super().__init__( + train=train, + subsample_fraction=subsample_fraction, + transform=transform, + target_transform=target_transform, + init_token=init_token, + pad_token=pad_token, + eos_token=eos_token, + lower=lower, + ) + + @property + def input_shape(self) -> Tuple: + """Input shape of the data.""" + return self.data.shape[1:] if self.data is not None else None + + @property + def output_shape(self) -> Tuple: + """Output shape of the data.""" + return ( + self.targets.shape[1:] + (self.num_classes,) + if self.targets is not None + else None + ) + + def load_or_generate_data(self) -> None: + """Load or generate dataset data.""" + if not PROCESSED_DATA_FILENAME.exists(): + PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + logger.info("Downloading IAM lines...") + download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME) + with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: + self._data = f[f"x_{self.split}"][:] + self._targets = f[f"y_{self.split}"][:] + self._subsample() + + def __repr__(self) -> str: + """Print info about the dataset.""" + return ( + "IAM Lines Dataset\n" # pylint: disable=no-member + f"Number classes: {self.num_classes}\n" + f"Mapping: {self.mapper.mapping}\n" + f"Data: {self.data.shape}\n" + f"Targets: {self.targets.shape}\n" + ) + + def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]: + """Fetches data, target pair of the dataset for a given and index or 
indices. + + Args: + index (Union[int, Tensor]): Either a list or int of indices/index. + + Returns: + Tuple[Tensor, Tensor]: Data target pair. + + """ + if torch.is_tensor(index): + index = index.tolist() + + data = self.data[index] + targets = self.targets[index] + + if self.transform: + data = self.transform(data) + + if self.target_transform: + targets = self.target_transform(targets) + + return data, targets diff --git a/text_recognizer/datasets/iam_paragraphs_dataset.py b/text_recognizer/datasets/iam_paragraphs_dataset.py new file mode 100644 index 0000000..8ba5142 --- /dev/null +++ b/text_recognizer/datasets/iam_paragraphs_dataset.py @@ -0,0 +1,291 @@ +"""IamParagraphsDataset class and functions for data processing.""" +import random +from typing import Callable, Dict, List, Optional, Tuple, Union + +import click +import cv2 +import h5py +from loguru import logger +import numpy as np +import torch +from torch import Tensor +from torchvision.transforms import ToTensor + +from text_recognizer import util +from text_recognizer.datasets.dataset import Dataset +from text_recognizer.datasets.iam_dataset import IamDataset +from text_recognizer.datasets.util import ( + compute_sha256, + DATA_DIRNAME, + download_url, + EmnistMapper, +) + +INTERIM_DATA_DIRNAME = DATA_DIRNAME / "interim" / "iam_paragraphs" +DEBUG_CROPS_DIRNAME = INTERIM_DATA_DIRNAME / "debug_crops" +PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_paragraphs" +CROPS_DIRNAME = PROCESSED_DATA_DIRNAME / "crops" +GT_DIRNAME = PROCESSED_DATA_DIRNAME / "gt" + +PARAGRAPH_BUFFER = 50 # Pixels in the IAM form images to leave around the lines. +TEST_FRACTION = 0.2 +SEED = 4711 + + +class IamParagraphsDataset(Dataset): + """IAM Paragraphs dataset for paragraphs of handwritten text.""" + + def __init__( + self, + train: bool = False, + subsample_fraction: float = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: + super().__init__( + train=train, + subsample_fraction=subsample_fraction, + transform=transform, + target_transform=target_transform, + ) + # Load Iam dataset. + self.iam_dataset = IamDataset() + + self.num_classes = 3 + self._input_shape = (256, 256) + self._output_shape = self._input_shape + (self.num_classes,) + self._ids = None + + def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]: + """Fetches data, target pair of the dataset for a given and index or indices. + + Args: + index (Union[int, Tensor]): Either a list or int of indices/index. + + Returns: + Tuple[Tensor, Tensor]: Data target pair. 
+ + """ + if torch.is_tensor(index): + index = index.tolist() + + data = self.data[index] + targets = self.targets[index] + + seed = np.random.randint(SEED) + random.seed(seed) # apply this seed to target tranfsorms + torch.manual_seed(seed) # needed for torchvision 0.7 + if self.transform: + data = self.transform(data) + + random.seed(seed) # apply this seed to target tranfsorms + torch.manual_seed(seed) # needed for torchvision 0.7 + if self.target_transform: + targets = self.target_transform(targets) + + return data, targets.long() + + @property + def ids(self) -> Tensor: + """Ids of the dataset.""" + return self._ids + + def get_data_and_target_from_id(self, id_: str) -> Tuple[Tensor, Tensor]: + """Get data target pair from id.""" + ind = self.ids.index(id_) + return self.data[ind], self.targets[ind] + + def load_or_generate_data(self) -> None: + """Load or generate dataset data.""" + num_actual = len(list(CROPS_DIRNAME.glob("*.jpg"))) + num_targets = len(self.iam_dataset.line_regions_by_id) + + if num_actual < num_targets - 2: + self._process_iam_paragraphs() + + self._data, self._targets, self._ids = _load_iam_paragraphs() + self._get_random_split() + self._subsample() + + def _get_random_split(self) -> None: + np.random.seed(SEED) + num_train = int((1 - TEST_FRACTION) * self.data.shape[0]) + indices = np.random.permutation(self.data.shape[0]) + train_indices, test_indices = indices[:num_train], indices[num_train:] + if self.train: + self._data = self.data[train_indices] + self._targets = self.targets[train_indices] + else: + self._data = self.data[test_indices] + self._targets = self.targets[test_indices] + + def _process_iam_paragraphs(self) -> None: + """Crop the part with the text. + + For each page, crop out the part of it that correspond to the paragraph of text, and make sure all crops are + self.input_shape. The ground truth data is the same size, with a one-hot vector at each pixel + corresponding to labels 0=background, 1=odd-numbered line, 2=even-numbered line + """ + crop_dims = self._decide_on_crop_dims() + CROPS_DIRNAME.mkdir(parents=True, exist_ok=True) + DEBUG_CROPS_DIRNAME.mkdir(parents=True, exist_ok=True) + GT_DIRNAME.mkdir(parents=True, exist_ok=True) + logger.info( + f"Cropping paragraphs, generating ground truth, and saving debugging images to {DEBUG_CROPS_DIRNAME}" + ) + for filename in self.iam_dataset.form_filenames: + id_ = filename.stem + line_region = self.iam_dataset.line_regions_by_id[id_] + _crop_paragraph_image(filename, line_region, crop_dims, self.input_shape) + + def _decide_on_crop_dims(self) -> Tuple[int, int]: + """Decide on the dimensions to crop out of the form image. + + Since image width is larger than a comfortable crop around the longest paragraph, + we will make the crop a square form factor. + And since the found dimensions 610x610 are pretty close to 512x512, + we might as well resize crops and make it exactly that, which lets us + do all kinds of power-of-2 pooling and upsampling should we choose to. + + Returns: + Tuple[int, int]: A tuple of crop dimensions. + + Raises: + RuntimeError: When max crop height is larger than max crop width. 
+ + """ + + sample_form_filename = self.iam_dataset.form_filenames[0] + sample_image = util.read_image(sample_form_filename, grayscale=True) + max_crop_width = sample_image.shape[1] + max_crop_height = _get_max_paragraph_crop_height( + self.iam_dataset.line_regions_by_id + ) + if not max_crop_height <= max_crop_width: + raise RuntimeError( + f"Max crop height is larger then max crop width: {max_crop_height} >= {max_crop_width}" + ) + + crop_dims = (max_crop_width, max_crop_width) + logger.info( + f"Max crop width and height were found to be {max_crop_width}x{max_crop_height}." + ) + logger.info(f"Setting them to {max_crop_width}x{max_crop_width}") + return crop_dims + + def __repr__(self) -> str: + """Return info about the dataset.""" + return ( + "IAM Paragraph Dataset\n" # pylint: disable=no-member + f"Num classes: {self.num_classes}\n" + f"Data: {self.data.shape}\n" + f"Targets: {self.targets.shape}\n" + ) + + +def _get_max_paragraph_crop_height(line_regions_by_id: Dict) -> int: + heights = [] + for regions in line_regions_by_id.values(): + min_y1 = min(region["y1"] for region in regions) - PARAGRAPH_BUFFER + max_y2 = max(region["y2"] for region in regions) + PARAGRAPH_BUFFER + height = max_y2 - min_y1 + heights.append(height) + return max(heights) + + +def _crop_paragraph_image( + filename: str, line_regions: Dict, crop_dims: Tuple[int, int], final_dims: Tuple +) -> None: + image = util.read_image(filename, grayscale=True) + + min_y1 = min(region["y1"] for region in line_regions) - PARAGRAPH_BUFFER + max_y2 = max(region["y2"] for region in line_regions) + PARAGRAPH_BUFFER + height = max_y2 - min_y1 + crop_height = crop_dims[0] + buffer = (crop_height - height) // 2 + + # Generate image crop. + image_crop = 255 * np.ones(crop_dims, dtype=np.uint8) + try: + image_crop[buffer : buffer + height] = image[min_y1:max_y2] + except Exception as e: # pylint: disable=broad-except + logger.error(f"Rescued {filename}: {e}") + return + + # Generate ground truth. + gt_image = np.zeros_like(image_crop, dtype=np.uint8) + for index, region in enumerate(line_regions): + gt_image[ + (region["y1"] - min_y1 + buffer) : (region["y2"] - min_y1 + buffer), + region["x1"] : region["x2"], + ] = (index % 2 + 1) + + # Generate image for debugging. 
+ import matplotlib.pyplot as plt + + cmap = plt.get_cmap("Set1") + image_crop_for_debug = np.dstack([image_crop, image_crop, image_crop]) + for index, region in enumerate(line_regions): + color = [255 * _ for _ in cmap(index)[:-1]] + cv2.rectangle( + image_crop_for_debug, + (region["x1"], region["y1"] - min_y1 + buffer), + (region["x2"], region["y2"] - min_y1 + buffer), + color, + 3, + ) + image_crop_for_debug = cv2.resize( + image_crop_for_debug, final_dims, interpolation=cv2.INTER_AREA + ) + util.write_image(image_crop_for_debug, DEBUG_CROPS_DIRNAME / f"{filename.stem}.jpg") + + image_crop = cv2.resize(image_crop, final_dims, interpolation=cv2.INTER_AREA) + util.write_image(image_crop, CROPS_DIRNAME / f"{filename.stem}.jpg") + + gt_image = cv2.resize(gt_image, final_dims, interpolation=cv2.INTER_NEAREST) + util.write_image(gt_image, GT_DIRNAME / f"{filename.stem}.png") + + +def _load_iam_paragraphs() -> None: + logger.info("Loading IAM paragraph crops and ground truth from image files...") + images = [] + gt_images = [] + ids = [] + for filename in CROPS_DIRNAME.glob("*.jpg"): + id_ = filename.stem + image = util.read_image(filename, grayscale=True) + image = 1.0 - image / 255 + + gt_filename = GT_DIRNAME / f"{id_}.png" + gt_image = util.read_image(gt_filename, grayscale=True) + + images.append(image) + gt_images.append(gt_image) + ids.append(id_) + images = np.array(images).astype(np.float32) + gt_images = np.array(gt_images).astype(np.uint8) + ids = np.array(ids) + return images, gt_images, ids + + +@click.command() +@click.option( + "--subsample_fraction", + type=float, + default=None, + help="The subsampling factor of the dataset.", +) +def main(subsample_fraction: float) -> None: + """Load dataset and print info.""" + logger.info("Creating train set...") + dataset = IamParagraphsDataset(train=True, subsample_fraction=subsample_fraction) + dataset.load_or_generate_data() + print(dataset) + logger.info("Creating test set...") + dataset = IamParagraphsDataset(subsample_fraction=subsample_fraction) + dataset.load_or_generate_data() + print(dataset) + + +if __name__ == "__main__": + main() diff --git a/text_recognizer/datasets/iam_preprocessor.py b/text_recognizer/datasets/iam_preprocessor.py new file mode 100644 index 0000000..a93eb00 --- /dev/null +++ b/text_recognizer/datasets/iam_preprocessor.py @@ -0,0 +1,196 @@ +"""Preprocessor for extracting word letters from the IAM dataset. + +The code is mostly stolen from: + + https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py + +""" + +import collections +import itertools +from pathlib import Path +import re +from typing import List, Optional, Union + +import click +from loguru import logger +import torch + + +def load_metadata( + data_dir: Path, wordsep: str, use_words: bool = False +) -> collections.defaultdict: + """Loads IAM metadata and returns it as a dictionary.""" + forms = collections.defaultdict(list) + filename = "words.txt" if use_words else "lines.txt" + + with open(data_dir / "ascii" / filename, "r") as f: + lines = (line.strip().split() for line in f if line[0] != "#") + for line in lines: + # Skip word segmentation errors. 
+ if use_words and line[1] == "err": + continue + text = " ".join(line[8:]) + + # Remove garbage tokens: + text = text.replace("#", "") + + # Swap word sep form | to wordsep + text = re.sub(r"\|+|\s", wordsep, text).strip(wordsep) + form_key = "-".join(line[0].split("-")[:2]) + line_key = "-".join(line[0].split("-")[:3]) + box_idx = 4 - use_words + box = tuple(int(val) for val in line[box_idx : box_idx + 4]) + forms[form_key].append({"key": line_key, "box": box, "text": text}) + return forms + + +class Preprocessor: + """A preprocessor for the IAM dataset.""" + + # TODO: add lower case only to when generating... + + def __init__( + self, + data_dir: Union[str, Path], + num_features: int, + tokens_path: Optional[Union[str, Path]] = None, + lexicon_path: Optional[Union[str, Path]] = None, + use_words: bool = False, + prepend_wordsep: bool = False, + ) -> None: + self.wordsep = "▁" + self._use_word = use_words + self._prepend_wordsep = prepend_wordsep + + self.data_dir = Path(data_dir) + + self.forms = load_metadata(self.data_dir, self.wordsep, use_words=use_words) + + # Load the set of graphemes: + graphemes = set() + for _, form in self.forms.items(): + for line in form: + graphemes.update(line["text"].lower()) + self.graphemes = sorted(graphemes) + + # Build the token-to-index and index-to-token maps. + if tokens_path is not None: + with open(tokens_path, "r") as f: + self.tokens = [line.strip() for line in f] + else: + self.tokens = self.graphemes + + if lexicon_path is not None: + with open(lexicon_path, "r") as f: + lexicon = (line.strip().split() for line in f) + lexicon = {line[0]: line[1:] for line in lexicon} + self.lexicon = lexicon + else: + self.lexicon = None + + self.graphemes_to_index = {t: i for i, t in enumerate(self.graphemes)} + self.tokens_to_index = {t: i for i, t in enumerate(self.tokens)} + self.num_features = num_features + self.text = [] + + @property + def num_tokens(self) -> int: + """Returns the number or tokens.""" + return len(self.tokens) + + @property + def use_words(self) -> bool: + """If words are used.""" + return self._use_word + + def extract_train_text(self) -> None: + """Extracts training text.""" + keys = [] + with open(self.data_dir / "task" / "trainset.txt") as f: + keys.extend((line.strip() for line in f)) + + for _, examples in self.forms.items(): + for example in examples: + if example["key"] not in keys: + continue + self.text.append(example["text"].lower()) + + def to_index(self, line: str) -> torch.LongTensor: + """Converts text to a tensor of indices.""" + token_to_index = self.graphemes_to_index + if self.lexicon is not None: + if len(line) > 0: + # If the word is not found in the lexicon, fall back to letters. 
+ line = [ + t + for w in line.split(self.wordsep) + for t in self.lexicon.get(w, self.wordsep + w) + ] + token_to_index = self.tokens_to_index + if self._prepend_wordsep: + line = itertools.chain([self.wordsep], line) + return torch.LongTensor([token_to_index[t] for t in line]) + + def to_text(self, indices: List[int]) -> str: + """Converts indices to text.""" + # Roughly the inverse of `to_index` + encoding = self.graphemes + if self.lexicon is not None: + encoding = self.tokens + return self._post_process(encoding[i] for i in indices) + + def tokens_to_text(self, indices: List[int]) -> str: + """Converts tokens to text.""" + return self._post_process(self.tokens[i] for i in indices) + + def _post_process(self, indices: List[int]) -> str: + """A list join.""" + return "".join(indices).strip(self.wordsep) + + +@click.command() +@click.option("--data_dir", type=str, default=None, help="Path to iam dataset") +@click.option( + "--use_words", is_flag=True, help="Load word segmented dataset instead of lines" +) +@click.option( + "--save_text", type=str, default=None, help="Path to save parsed train text" +) +@click.option("--save_tokens", type=str, default=None, help="Path to save tokens") +def cli( + data_dir: Optional[str], + use_words: bool, + save_text: Optional[str], + save_tokens: Optional[str], +) -> None: + """CLI for extracting text data from the iam dataset.""" + if data_dir is None: + data_dir = ( + Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb" + ) + logger.debug(f"Using data dir: {data_dir}") + if not data_dir.exists(): + raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") + else: + data_dir = Path(data_dir) + + preprocessor = Preprocessor(data_dir, 64, use_words=use_words) + preprocessor.extract_train_text() + + processed_dir = data_dir.parents[2] / "processed" / "iam_lines" + logger.debug(f"Saving processed files at: {processed_dir}") + + if save_text is not None: + logger.info("Saving training text") + with open(processed_dir / save_text, "w") as f: + f.write("\n".join(t for t in preprocessor.text)) + + if save_tokens is not None: + logger.info("Saving tokens") + with open(processed_dir / save_tokens, "w") as f: + f.write("\n".join(preprocessor.tokens)) + + +if __name__ == "__main__": + cli() diff --git a/text_recognizer/datasets/sentence_generator.py b/text_recognizer/datasets/sentence_generator.py new file mode 100644 index 0000000..dd76652 --- /dev/null +++ b/text_recognizer/datasets/sentence_generator.py @@ -0,0 +1,81 @@ +"""Downloading the Brown corpus with NLTK for sentence generating.""" + +import itertools +import re +import string +from typing import Optional + +import nltk +from nltk.corpus.reader.util import ConcatenatedCorpusView +import numpy as np + +from text_recognizer.datasets.util import DATA_DIRNAME + +NLTK_DATA_DIRNAME = DATA_DIRNAME / "raw" / "nltk" + + +class SentenceGenerator: + """Generates text sentences using the Brown corpus.""" + + def __init__(self, max_length: Optional[int] = None) -> None: + """Loads the corpus and sets word start indices.""" + self.corpus = brown_corpus() + self.word_start_indices = [0] + [ + _.start(0) + 1 for _ in re.finditer(" ", self.corpus) + ] + self.max_length = max_length + + def generate(self, max_length: Optional[int] = None) -> str: + """Generates a word or sentences from the Brown corpus. + + Sample a string from the Brown corpus of length at least one word and at most max_length, padding to + max_length with the '_' characters if sentence is shorter. 
+ + Args: + max_length (Optional[int]): The maximum number of characters in the sentence. Defaults to None. + + Returns: + str: A sentence from the Brown corpus. + + Raises: + ValueError: If max_length was not specified at initialization and not given as an argument. + + """ + if max_length is None: + max_length = self.max_length + if max_length is None: + raise ValueError( + "Must provide max_length to this method or when making this object." + ) + + index = np.random.randint(0, len(self.word_start_indices) - 1) + start_index = self.word_start_indices[index] + end_index_candidates = [] + for index in range(index + 1, len(self.word_start_indices)): + if self.word_start_indices[index] - start_index > max_length: + break + end_index_candidates.append(self.word_start_indices[index]) + end_index = np.random.choice(end_index_candidates) + sampled_text = self.corpus[start_index:end_index].strip() + padding = "_" * (max_length - len(sampled_text)) + return sampled_text + padding + + +def brown_corpus() -> str: + """Returns a single string with the Brown corpus with all punctuations stripped.""" + sentences = load_nltk_brown_corpus() + corpus = " ".join(itertools.chain.from_iterable(sentences)) + corpus = corpus.translate({ord(c): None for c in string.punctuation}) + corpus = re.sub(" +", " ", corpus) + return corpus + + +def load_nltk_brown_corpus() -> ConcatenatedCorpusView: + """Load the Brown corpus using the NLTK library.""" + nltk.data.path.append(NLTK_DATA_DIRNAME) + try: + nltk.corpus.brown.sents() + except LookupError: + NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + nltk.download("brown", download_dir=NLTK_DATA_DIRNAME) + return nltk.corpus.brown.sents() diff --git a/text_recognizer/datasets/transforms.py b/text_recognizer/datasets/transforms.py new file mode 100644 index 0000000..b6a48f5 --- /dev/null +++ b/text_recognizer/datasets/transforms.py @@ -0,0 +1,266 @@ +"""Transforms for PyTorch datasets.""" +from abc import abstractmethod +from pathlib import Path +import random +from typing import Any, Optional, Union + +from loguru import logger +import numpy as np +from PIL import Image +import torch +from torch import Tensor +import torch.nn.functional as F +from torchvision import transforms +from torchvision.transforms import ( + ColorJitter, + Compose, + Normalize, + RandomAffine, + RandomHorizontalFlip, + RandomRotation, + ToPILImage, + ToTensor, +) + +from text_recognizer.datasets.iam_preprocessor import Preprocessor +from text_recognizer.datasets.util import EmnistMapper + + +class RandomResizeCrop: + """Image transform with random resize and crop applied. 
+ + Stolen from + + https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py + + """ + + def __init__(self, jitter: int = 10, ratio: float = 0.5) -> None: + self.jitter = jitter + self.ratio = ratio + + def __call__(self, img: np.ndarray) -> np.ndarray: + """Applies random padding, cropping, and resizing to an image.""" + w, h = img.size + + # pad with white: + img = transforms.functional.pad(img, self.jitter, fill=255) + + # crop at random (x, y): + x = self.jitter + random.randint(-self.jitter, self.jitter) + y = self.jitter + random.randint(-self.jitter, self.jitter) + + # randomize aspect ratio: + size_w = w * random.uniform(1 - self.ratio, 1 + self.ratio) + size = (h, int(size_w)) + img = transforms.functional.resized_crop(img, y, x, h, w, size) + return img + + +class Transpose: + """Transposes the EMNIST image to the correct orientation.""" + + def __call__(self, image: Image) -> np.ndarray: + """Swaps axis.""" + return np.array(image).swapaxes(0, 1) + + +class Resize: + """Resizes a tensor to a specified width.""" + + def __init__(self, width: int = 952) -> None: + # The default is 952 because of the IAM dataset. + self.width = width + + def __call__(self, image: Tensor) -> Tensor: + """Resize tensor in the last dimension.""" + return F.interpolate(image, size=self.width, mode="nearest") + + +class AddTokens: + """Adds start of sequence and end of sequence tokens to target tensor.""" + + def __init__(self, pad_token: str, eos_token: str, init_token: str = None) -> None: + self.init_token = init_token + self.pad_token = pad_token + self.eos_token = eos_token + if self.init_token is not None: + self.emnist_mapper = EmnistMapper( + init_token=self.init_token, + pad_token=self.pad_token, + eos_token=self.eos_token, + ) + else: + self.emnist_mapper = EmnistMapper( + pad_token=self.pad_token, eos_token=self.eos_token, + ) + self.pad_value = self.emnist_mapper(self.pad_token) + self.eos_value = self.emnist_mapper(self.eos_token) + + def __call__(self, target: Tensor) -> Tensor: + """Adds an SOS token to the beginning and an EOS token to the end of a target sequence.""" + dtype, device = target.dtype, target.device + + # Find where the padding starts. + pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item() + + target[pad_index] = self.eos_value + + if self.init_token is not None: + self.sos_value = self.emnist_mapper(self.init_token) + sos = torch.tensor([self.sos_value], dtype=dtype, device=device) + target = torch.cat([sos, target], dim=0) + + return target + + +class ApplyContrast: + """Sets everything below a threshold to zero, i.e.
increase contrast.""" + + def __init__(self, low: float = 0.0, high: float = 0.25) -> None: + self.low = low + self.high = high + + def __call__(self, x: Tensor) -> Tensor: + """Apply mask binary mask to input tensor.""" + mask = x > np.random.RandomState().uniform(low=self.low, high=self.high) + return x * mask + + +class Unsqueeze: + """Add a dimension to the tensor.""" + + def __call__(self, x: Tensor) -> Tensor: + """Adds dim.""" + return x.unsqueeze(0) + + +class Squeeze: + """Removes the first dimension of a tensor.""" + + def __call__(self, x: Tensor) -> Tensor: + """Removes first dim.""" + return x.squeeze(0) + + +class ToLower: + """Converts target to lower case.""" + + def __call__(self, target: Tensor) -> Tensor: + """Corrects index value in target tensor.""" + device = target.device + return torch.stack([x - 26 if x > 35 else x for x in target]).to(device) + + +class ToCharcters: + """Converts integers to characters.""" + + def __init__( + self, pad_token: str, eos_token: str, init_token: str = None, lower: bool = True + ) -> None: + self.init_token = init_token + self.pad_token = pad_token + self.eos_token = eos_token + if self.init_token is not None: + self.emnist_mapper = EmnistMapper( + init_token=self.init_token, + pad_token=self.pad_token, + eos_token=self.eos_token, + lower=lower, + ) + else: + self.emnist_mapper = EmnistMapper( + pad_token=self.pad_token, eos_token=self.eos_token, lower=lower + ) + + def __call__(self, y: Tensor) -> str: + """Converts a Tensor to a str.""" + return ( + "".join([self.emnist_mapper(int(i)) for i in y]) + .strip("_") + .replace(" ", "▁") + ) + + +class WordPieces: + """Abstract transform for word pieces.""" + + def __init__( + self, + num_features: int, + data_dir: Optional[Union[str, Path]] = None, + tokens: Optional[Union[str, Path]] = None, + lexicon: Optional[Union[str, Path]] = None, + use_words: bool = False, + prepend_wordsep: bool = False, + ) -> None: + if data_dir is None: + data_dir = ( + Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb" + ) + logger.debug(f"Using data dir: {data_dir}") + if not data_dir.exists(): + raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") + else: + data_dir = Path(data_dir) + processed_path = ( + Path(__file__).resolve().parents[3] / "data" / "processed" / "iam_lines" + ) + tokens_path = processed_path / tokens + lexicon_path = processed_path / lexicon + + self.preprocessor = Preprocessor( + data_dir, + num_features, + tokens_path, + lexicon_path, + use_words, + prepend_wordsep, + ) + + @abstractmethod + def __call__(self, *args, **kwargs) -> Any: + """Transforms input.""" + ... 
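To show how these transform classes are meant to be wired together, here is a minimal sketch of the list-of-dicts format consumed by Dataset._configure_transform() earlier in this diff. The sketch is not part of the diff, and the token strings and parameter values are made-up examples:

# Sketch: each entry names a class in this module plus its kwargs.
transform_config = [
    {"type": "Transpose", "args": None},  # fix EMNIST orientation (returns an ndarray)
    {"type": "ToTensor", "args": None},  # re-exported from torchvision by this module
    {"type": "ApplyContrast", "args": {"low": 0.0, "high": 0.25}},
]
target_transform_config = [
    # AddTokens assumes the target is already padded with the pad token.
    {"type": "AddTokens", "args": {"pad_token": "_", "eos_token": "<e>"}},
]

Note that _configure_target_transform() always prepends torch.tensor, so target configs only list the symbolic transforms.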
+
+
+class ToWordPieces(WordPieces):
+    """Transforms str to word pieces."""
+
+    def __init__(
+        self,
+        num_features: int,
+        data_dir: Optional[Union[str, Path]] = None,
+        tokens: Optional[Union[str, Path]] = None,
+        lexicon: Optional[Union[str, Path]] = None,
+        use_words: bool = False,
+        prepend_wordsep: bool = False,
+    ) -> None:
+        super().__init__(
+            num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep
+        )
+
+    def __call__(self, line: str) -> Tensor:
+        """Transforms str to word pieces."""
+        return self.preprocessor.to_index(line)
+
+
+class ToText(WordPieces):
+    """Takes word pieces and converts them to text."""
+
+    def __init__(
+        self,
+        num_features: int,
+        data_dir: Optional[Union[str, Path]] = None,
+        tokens: Optional[Union[str, Path]] = None,
+        lexicon: Optional[Union[str, Path]] = None,
+        use_words: bool = False,
+        prepend_wordsep: bool = False,
+    ) -> None:
+        super().__init__(
+            num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep
+        )
+
+    def __call__(self, x: Tensor) -> str:
+        """Converts tensor to text."""
+        return self.preprocessor.to_text(x.tolist())
diff --git a/text_recognizer/datasets/util.py b/text_recognizer/datasets/util.py
new file mode 100644
index 0000000..da87756
--- /dev/null
+++ b/text_recognizer/datasets/util.py
@@ -0,0 +1,209 @@
+"""Util functions for datasets."""
+import hashlib
+import json
+import os
+from pathlib import Path
+import string
+from typing import Dict, List, Optional, Union
+from urllib.request import urlretrieve
+
+from loguru import logger
+import numpy as np
+import torch
+from torch import Tensor
+from torchvision.datasets import EMNIST
+from tqdm import tqdm
+
+DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data"
+ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json"
+
+
+def save_emnist_essentials(emnist_dataset: EMNIST) -> None:
+    """Extracts and saves the EMNIST essentials."""
+    labels = emnist_dataset.classes
+    labels.sort()
+    mapping = [(i, str(label)) for i, label in enumerate(labels)]
+    essentials = {
+        "mapping": mapping,
+        "input_shape": tuple(np.array(emnist_dataset[0][0]).shape[:]),
+    }
+    logger.info("Saving emnist essentials...")
+    with open(ESSENTIALS_FILENAME, "w") as f:
+        json.dump(essentials, f)
+
+
+def download_emnist() -> None:
+    """Download the EMNIST dataset via the PyTorch class."""
+    logger.info(f"Data directory is: {DATA_DIRNAME}")
+    dataset = EMNIST(root=DATA_DIRNAME, split="byclass", download=True)
+    save_emnist_essentials(dataset)
+
+
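# For reference, the essentials file written above has (schematically) this
# shape; the concrete values here are illustrative:
#
#     {"mapping": [[0, "0"], [1, "1"], ..., [61, "z"]], "input_shape": [28, 28]}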
+class EmnistMapper:
+    """Mapper between network output and EMNIST characters."""
+
+    def __init__(
+        self,
+        pad_token: str,
+        init_token: Optional[str] = None,
+        eos_token: Optional[str] = None,
+        lower: bool = False,
+    ) -> None:
+        """Loads the emnist essentials file with the mapping and input shape."""
+        self.init_token = init_token
+        self.pad_token = pad_token
+        self.eos_token = eos_token
+        self.lower = lower
+
+        self.essentials = self._load_emnist_essentials()
+        # Load dataset information.
+        self._mapping = dict(self.essentials["mapping"])
+        self._augment_emnist_mapping()
+        self._inverse_mapping = {v: k for k, v in self.mapping.items()}
+        self._num_classes = len(self.mapping)
+        self._input_shape = self.essentials["input_shape"]
+
+    def __call__(self, token: Union[str, int, np.uint8, Tensor]) -> Union[str, int]:
+        """Maps the token to an EMNIST character or character index.
+
+        If the token is an integer (index), the method will return the EMNIST character corresponding to that index.
+        If the token is a str (EMNIST character), the method will return the corresponding index for that character.
+
+        Args:
+            token (Union[str, int, np.uint8, Tensor]): Either a string or index (integer).
+
+        Returns:
+            Union[str, int]: The mapping result.
+
+        Raises:
+            KeyError: If the index or string does not exist in the mapping.
+
+        """
+        if (
+            isinstance(token, (np.uint8, int)) or torch.is_tensor(token)
+        ) and int(token) in self.mapping:
+            return self.mapping[int(token)]
+        elif isinstance(token, str) and token in self._inverse_mapping:
+            return self._inverse_mapping[token]
+        else:
+            raise KeyError(f"Token {token} does not exist in the mappings.")
+
+    @property
+    def mapping(self) -> Dict:
+        """Returns the mapping between index and character."""
+        return self._mapping
+
+    @property
+    def inverse_mapping(self) -> Dict:
+        """Returns the mapping between character and index."""
+        return self._inverse_mapping
+
+    @property
+    def num_classes(self) -> int:
+        """Returns the number of classes in the dataset."""
+        return self._num_classes
+
+    @property
+    def input_shape(self) -> List[int]:
+        """Returns the input shape of the EMNIST characters."""
+        return self._input_shape
+
+    def _load_emnist_essentials(self) -> Dict:
+        """Load the EMNIST mapping."""
+        with open(str(ESSENTIALS_FILENAME)) as f:
+            essentials = json.load(f)
+        return essentials
+
+    def _augment_emnist_mapping(self) -> None:
+        """Augment the mapping with extra symbols."""
+        # Extra symbols in IAM dataset
+        if self.lower:
+            self._mapping = {
+                k: str(v)
+                for k, v in enumerate(list(range(10)) + list(string.ascii_lowercase))
+            }
+
+        extra_symbols = [
+            " ",
+            "!",
+            '"',
+            "#",
+            "&",
+            "'",
+            "(",
+            ")",
+            "*",
+            "+",
+            ",",
+            "-",
+            ".",
+            "/",
+            ":",
+            ";",
+            "?",
+        ]
+
+        # The padding symbol also acts as the blank symbol.
+        extra_symbols.append(self.pad_token)
+
+        if self.init_token is not None:
+            extra_symbols.append(self.init_token)
+
+        if self.eos_token is not None:
+            extra_symbols.append(self.eos_token)
+
+        max_key = max(self.mapping.keys())
+        extra_mapping = {}
+        for i, symbol in enumerate(extra_symbols):
+            extra_mapping[max_key + 1 + i] = symbol
+
+        self._mapping = {**self.mapping, **extra_mapping}
+
+
+def compute_sha256(filename: Union[Path, str]) -> str:
+    """Returns the SHA256 checksum of a file."""
+    with open(filename, "rb") as f:
+        return hashlib.sha256(f.read()).hexdigest()
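
# Usage sketch of the checksum helper, mirroring the verification performed in
# `_download_raw_dataset` below (the filename here is hypothetical):
#
#     if compute_sha256("data/raw/iam/iamdb.tgz") != metadata["sha256"]:
#         raise ValueError("Downloaded data file SHA-256 does not match metadata.")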
+ """ + if total_size is not None: + self.total = total_size # pylint: disable=attribute-defined-outside-init + self.update(blocks * block_size - self.n) + + +def download_url(url: str, filename: str) -> None: + """Downloads a file from url to filename, with a progress bar.""" + with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: + urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec + + +def _download_raw_dataset(metadata: Dict) -> None: + if os.path.exists(metadata["filename"]): + return + logger.info(f"Downloading raw dataset from {metadata['url']}...") + download_url(metadata["url"], metadata["filename"]) + logger.info("Computing SHA-256...") + sha256 = compute_sha256(metadata["filename"]) + if sha256 != metadata["sha256"]: + raise ValueError( + "Downloaded data file SHA-256 does not match that listed in metadata document." + ) diff --git a/text_recognizer/line_predictor.py b/text_recognizer/line_predictor.py new file mode 100644 index 0000000..8e348fe --- /dev/null +++ b/text_recognizer/line_predictor.py @@ -0,0 +1,28 @@ +"""LinePredictor class.""" +import importlib +from typing import Tuple, Union + +import numpy as np +from torch import nn + +from text_recognizer import datasets, networks +from text_recognizer.models import TransformerModel +from text_recognizer.util import read_image + + +class LinePredictor: + """Given an image of a line of handwritten text, recognizes the text content.""" + + def __init__(self, dataset: str, network_fn: str) -> None: + network_fn = getattr(networks, network_fn) + dataset = getattr(datasets, dataset) + self.model = TransformerModel(network_fn=network_fn, dataset=dataset) + self.model.eval() + + def predict(self, image_or_filename: Union[np.ndarray, str]) -> Tuple[str, float]: + """Predict on a single images contianing a handwritten character.""" + if isinstance(image_or_filename, str): + image = read_image(image_or_filename, grayscale=True) + else: + image = image_or_filename + return self.model.predict_on_image(image) diff --git a/text_recognizer/models/__init__.py b/text_recognizer/models/__init__.py new file mode 100644 index 0000000..7647d7e --- /dev/null +++ b/text_recognizer/models/__init__.py @@ -0,0 +1,18 @@ +"""Model modules.""" +from .base import Model +from .character_model import CharacterModel +from .crnn_model import CRNNModel +from .ctc_transformer_model import CTCTransformerModel +from .segmentation_model import SegmentationModel +from .transformer_model import TransformerModel +from .vqvae_model import VQVAEModel + +__all__ = [ + "CharacterModel", + "CRNNModel", + "CTCTransformerModel", + "Model", + "SegmentationModel", + "TransformerModel", + "VQVAEModel", +] diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py new file mode 100644 index 0000000..70f4cdb --- /dev/null +++ b/text_recognizer/models/base.py @@ -0,0 +1,455 @@ +"""Abstract Model class for PyTorch neural networks.""" + +from abc import ABC, abstractmethod +from glob import glob +import importlib +from pathlib import Path +import re +import shutil +from typing import Callable, Dict, List, Optional, Tuple, Type, Union + +from loguru import logger +import torch +from torch import nn +from torch import Tensor +from torch.optim.swa_utils import AveragedModel, SWALR +from torch.utils.data import DataLoader, Dataset, random_split +from torchsummary import summary + +from text_recognizer import datasets +from text_recognizer import networks +from text_recognizer.datasets import EmnistMapper + +WEIGHT_DIRNAME = 
Path(__file__).parents[1].resolve() / "weights" + + +class Model(ABC): + """Abstract Model class with composition of different parts defining a PyTorch neural network.""" + + def __init__( + self, + network_fn: str, + dataset: str, + network_args: Optional[Dict] = None, + dataset_args: Optional[Dict] = None, + metrics: Optional[Dict] = None, + criterion: Optional[Callable] = None, + criterion_args: Optional[Dict] = None, + optimizer: Optional[Callable] = None, + optimizer_args: Optional[Dict] = None, + lr_scheduler: Optional[Callable] = None, + lr_scheduler_args: Optional[Dict] = None, + swa_args: Optional[Dict] = None, + device: Optional[str] = None, + ) -> None: + """Base class, to be inherited by model for specific type of data. + + Args: + network_fn (str): The name of network. + dataset (str): The name dataset class. + network_args (Optional[Dict]): Arguments for the network. Defaults to None. + dataset_args (Optional[Dict]): Arguments for the dataset. + metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None. + criterion (Optional[Callable]): The criterion to evaluate the performance of the network. + Defaults to None. + criterion_args (Optional[Dict]): Dict of arguments for criterion. Defaults to None. + optimizer (Optional[Callable]): The optimizer for updating the weights. Defaults to None. + optimizer_args (Optional[Dict]): Dict of arguments for optimizer. Defaults to None. + lr_scheduler (Optional[Callable]): A PyTorch learning rate scheduler. Defaults to None. + lr_scheduler_args (Optional[Dict]): Dict of arguments for learning rate scheduler. Defaults to + None. + swa_args (Optional[Dict]): Dict of arguments for stochastic weight averaging. Defaults to + None. + device (Optional[str]): Name of the device to train on. Defaults to None. + + """ + self._name = f"{self.__class__.__name__}_{dataset}_{network_fn}" + # Has to be set in subclass. + self._mapper = None + + # Placeholder. + self._input_shape = None + + self.dataset_name = dataset + self.dataset = None + self.dataset_args = dataset_args + + # Placeholders for datasets. + self.train_dataset = None + self.val_dataset = None + self.test_dataset = None + + # Stochastic Weight Averaging placeholders. + self.swa_args = swa_args + self._swa_scheduler = None + self._swa_network = None + self._use_swa_model = False + + # Experiment directory. + self.model_dir = None + + # Flag for configured model. + self.is_configured = False + self.data_prepared = False + + # Flag for stopping training. + self.stop_training = False + + self._metrics = metrics if metrics is not None else None + + # Set the device. + self._device = ( + torch.device("cuda" if torch.cuda.is_available() else "cpu") + if device is None + else device + ) + + # Configure network. + self._network = None + self._network_args = network_args + self._configure_network(network_fn) + + # Place network on device (GPU). + self.to_device() + + # Loss and Optimizer placeholders for before loading. + self._criterion = criterion + self.criterion_args = criterion_args + + self._optimizer = optimizer + self.optimizer_args = optimizer_args + + self._lr_scheduler = lr_scheduler + self.lr_scheduler_args = lr_scheduler_args + + def configure_model(self) -> None: + """Configures criterion and optimizers.""" + if not self.is_configured: + self._configure_criterion() + self._configure_optimizers() + + # Set this flag to true to prevent the model from configuring again. 
+ self.is_configured = True + + def prepare_data(self) -> None: + """Prepare data for training.""" + # TODO add downloading. + if not self.data_prepared: + # Load dataset module. + self.dataset = getattr(datasets, self.dataset_name) + + # Load train dataset. + train_dataset = self.dataset(train=True, **self.dataset_args["args"]) + train_dataset.load_or_generate_data() + + # Set input shape. + self._input_shape = train_dataset.input_shape + + # Split train dataset into a training and validation partition. + dataset_len = len(train_dataset) + train_len = int( + self.dataset_args["train_args"]["train_fraction"] * dataset_len + ) + val_len = dataset_len - train_len + self.train_dataset, self.val_dataset = random_split( + train_dataset, lengths=[train_len, val_len] + ) + + # Load test dataset. + self.test_dataset = self.dataset(train=False, **self.dataset_args["args"]) + self.test_dataset.load_or_generate_data() + + # Set the flag to true to disable ability to load data again. + self.data_prepared = True + + def train_dataloader(self) -> DataLoader: + """Returns data loader for training set.""" + return DataLoader( + self.train_dataset, + batch_size=self.dataset_args["train_args"]["batch_size"], + num_workers=self.dataset_args["train_args"]["num_workers"], + shuffle=True, + pin_memory=True, + ) + + def val_dataloader(self) -> DataLoader: + """Returns data loader for validation set.""" + return DataLoader( + self.val_dataset, + batch_size=self.dataset_args["train_args"]["batch_size"], + num_workers=self.dataset_args["train_args"]["num_workers"], + shuffle=True, + pin_memory=True, + ) + + def test_dataloader(self) -> DataLoader: + """Returns data loader for test set.""" + return DataLoader( + self.test_dataset, + batch_size=self.dataset_args["train_args"]["batch_size"], + num_workers=self.dataset_args["train_args"]["num_workers"], + shuffle=False, + pin_memory=True, + ) + + def _configure_network(self, network_fn: Type[nn.Module]) -> None: + """Loads the network.""" + # If no network arguments are given, load pretrained weights if they exist. + # Load network module. + network_fn = getattr(networks, network_fn) + if self._network_args is None: + self.load_weights(network_fn) + else: + self._network = network_fn(**self._network_args) + + def _configure_criterion(self) -> None: + """Loads the criterion.""" + self._criterion = ( + self._criterion(**self.criterion_args) + if self._criterion is not None + else None + ) + + def _configure_optimizers(self,) -> None: + """Loads the optimizers.""" + if self._optimizer is not None: + self._optimizer = self._optimizer( + self._network.parameters(), **self.optimizer_args + ) + else: + self._optimizer = None + + if self._optimizer and self._lr_scheduler is not None: + if "steps_per_epoch" in self.lr_scheduler_args: + self.lr_scheduler_args["steps_per_epoch"] = len(self.train_dataloader()) + + # Assume lr scheduler should update at each epoch if not specified. 
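
# For example, a OneCycleLR configuration might be passed as (values are
# illustrative, not defaults of this repo):
#
#     lr_scheduler_args = {"max_lr": 1e-3, "epochs": 10, "steps_per_epoch": None, "interval": "step"}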
+ if "interval" not in self.lr_scheduler_args: + interval = "epoch" + else: + interval = self.lr_scheduler_args.pop("interval") + self._lr_scheduler = { + "lr_scheduler": self._lr_scheduler( + self._optimizer, **self.lr_scheduler_args + ), + "interval": interval, + } + + if self.swa_args is not None: + self._swa_scheduler = { + "swa_scheduler": SWALR(self._optimizer, swa_lr=self.swa_args["lr"]), + "swa_start": self.swa_args["start"], + } + self._swa_network = AveragedModel(self._network).to(self.device) + + @property + def name(self) -> str: + """Returns the name of the model.""" + return self._name + + @property + def input_shape(self) -> Tuple[int, ...]: + """The input shape.""" + return self._input_shape + + @property + def mapper(self) -> EmnistMapper: + """Returns the mapper that maps between ints and chars.""" + return self._mapper + + @property + def mapping(self) -> Dict: + """Returns the mapping between network output and Emnist character.""" + return self._mapper.mapping if self._mapper is not None else None + + def eval(self) -> None: + """Sets the network to evaluation mode.""" + self._network.eval() + + def train(self) -> None: + """Sets the network to train mode.""" + self._network.train() + + @property + def device(self) -> str: + """Device where the weights are stored, i.e. cpu or cuda.""" + return self._device + + @property + def metrics(self) -> Optional[Dict]: + """Metrics.""" + return self._metrics + + @property + def criterion(self) -> Optional[Callable]: + """Criterion.""" + return self._criterion + + @property + def optimizer(self) -> Optional[Callable]: + """Optimizer.""" + return self._optimizer + + @property + def lr_scheduler(self) -> Optional[Dict]: + """Returns a directory with the learning rate scheduler.""" + return self._lr_scheduler + + @property + def swa_scheduler(self) -> Optional[Dict]: + """Returns a directory with the stochastic weight averaging scheduler.""" + return self._swa_scheduler + + @property + def swa_network(self) -> Optional[Callable]: + """Returns the stochastic weight averaging network.""" + return self._swa_network + + @property + def network(self) -> Type[nn.Module]: + """Neural network.""" + # Returns the SWA network if available. 
+
+    @property
+    def weights_filename(self) -> str:
+        """Filepath to the network weights."""
+        WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True)
+        return str(WEIGHT_DIRNAME / f"{self._name}_weights.pt")
+
+    def use_swa_model(self) -> None:
+        """Set to use predictions from SWA model."""
+        if self.swa_network is not None:
+            self._use_swa_model = True
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Feedforward pass with the network."""
+        if self._use_swa_model:
+            return self.swa_network(x)
+        else:
+            return self.network(x)
+
+    def summary(
+        self,
+        input_shape: Optional[Union[List, Tuple]] = None,
+        depth: int = 3,
+        device: Optional[str] = None,
+    ) -> None:
+        """Prints a summary of the network architecture."""
+        device = self.device if device is None else device
+
+        if input_shape is not None:
+            summary(self.network, input_shape, depth=depth, device=device)
+        elif self._input_shape is not None:
+            input_shape = tuple(self._input_shape)
+            summary(self.network, input_shape, depth=depth, device=device)
+        else:
+            logger.warning("Could not print summary as input shape is not set.")
+
+    def to_device(self) -> None:
+        """Places the network on the device (GPU)."""
+        self._network.to(self._device)
+
+    def _get_state_dict(self) -> Dict:
+        """Get the state dict of the model."""
+        state = {"model_state": self._network.state_dict()}
+
+        if self._optimizer is not None:
+            state["optimizer_state"] = self._optimizer.state_dict()
+
+        if self._lr_scheduler is not None:
+            state["scheduler_state"] = self._lr_scheduler["lr_scheduler"].state_dict()
+            state["scheduler_interval"] = self._lr_scheduler["interval"]
+
+        if self._swa_network is not None:
+            state["swa_network"] = self._swa_network.state_dict()
+
+        return state
+
+    def load_from_checkpoint(self, checkpoint_path: Union[str, Path]) -> None:
+        """Load a previously saved checkpoint.
+
+        Args:
+            checkpoint_path (Path): Path to the experiment with the checkpoint.
+
+        """
+        checkpoint_path = Path(checkpoint_path)
+        self.prepare_data()
+        self.configure_model()
+        logger.debug("Loading checkpoint...")
+        if not checkpoint_path.exists():
+            logger.debug(f"File does not exist: {str(checkpoint_path)}")
+
+        checkpoint = torch.load(str(checkpoint_path), map_location=self.device)
+        self._network.load_state_dict(checkpoint["model_state"])
+
+        if self._optimizer is not None:
+            self._optimizer.load_state_dict(checkpoint["optimizer_state"])
+
+        if self._lr_scheduler is not None:
+            # Does not work when loading from previous checkpoint and trying to train beyond the last max epochs
+            # with OneCycleLR.
+            if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR":
+                self._lr_scheduler["lr_scheduler"].load_state_dict(
+                    checkpoint["scheduler_state"]
+                )
+                self._lr_scheduler["interval"] = checkpoint["scheduler_interval"]
+
+        if self._swa_network is not None:
+            self._swa_network.load_state_dict(checkpoint["swa_network"])
+
+    def save_checkpoint(
+        self, checkpoint_path: Path, is_best: bool, epoch: int, val_metric: str
+    ) -> None:
+        """Saves a checkpoint of the model.
+
+        Args:
+            checkpoint_path (Path): Path to the experiment with the checkpoint.
+            is_best (bool): If it is the currently best model.
+            epoch (int): The epoch of the checkpoint.
+            val_metric (str): Validation metric.
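
        Example (illustrative call):
            model.save_checkpoint(Path("experiments/run_1"), is_best=True, epoch=10, val_metric="accuracy")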
+ + """ + state = self._get_state_dict() + state["is_best"] = is_best + state["epoch"] = epoch + state["network_args"] = self._network_args + + checkpoint_path.mkdir(parents=True, exist_ok=True) + + logger.debug("Saving checkpoint...") + filepath = str(checkpoint_path / "last.pt") + torch.save(state, filepath) + + if is_best: + logger.debug( + f"Found a new best {val_metric}. Saving best checkpoint and weights." + ) + shutil.copyfile(filepath, str(checkpoint_path / "best.pt")) + + def load_weights(self, network_fn: Optional[Type[nn.Module]] = None) -> None: + """Load the network weights.""" + logger.debug("Loading network with pretrained weights.") + filename = glob(self.weights_filename)[0] + if not filename: + raise FileNotFoundError( + f"Could not find any pretrained weights at {self.weights_filename}" + ) + # Loading state directory. + state_dict = torch.load(filename, map_location=torch.device(self._device)) + self._network_args = state_dict["network_args"] + weights = state_dict["model_state"] + + # Initializes the network with trained weights. + if network_fn is not None: + self._network = network_fn(**self._network_args) + self._network.load_state_dict(weights) + + if "swa_network" in state_dict: + self._swa_network = AveragedModel(self._network).to(self.device) + self._swa_network.load_state_dict(state_dict["swa_network"]) + + def save_weights(self, path: Path) -> None: + """Save the network weights.""" + logger.debug("Saving the best network weights.") + shutil.copyfile(str(path / "best.pt"), self.weights_filename) diff --git a/text_recognizer/models/character_model.py b/text_recognizer/models/character_model.py new file mode 100644 index 0000000..f9944f3 --- /dev/null +++ b/text_recognizer/models/character_model.py @@ -0,0 +1,88 @@ +"""Defines the CharacterModel class.""" +from typing import Callable, Dict, Optional, Tuple, Type, Union + +import numpy as np +import torch +from torch import nn +from torch.utils.data import Dataset +from torchvision.transforms import ToTensor + +from text_recognizer.datasets import EmnistMapper +from text_recognizer.models.base import Model + + +class CharacterModel(Model): + """Model for predicting characters from images.""" + + def __init__( + self, + network_fn: Type[nn.Module], + dataset: Type[Dataset], + network_args: Optional[Dict] = None, + dataset_args: Optional[Dict] = None, + metrics: Optional[Dict] = None, + criterion: Optional[Callable] = None, + criterion_args: Optional[Dict] = None, + optimizer: Optional[Callable] = None, + optimizer_args: Optional[Dict] = None, + lr_scheduler: Optional[Callable] = None, + lr_scheduler_args: Optional[Dict] = None, + swa_args: Optional[Dict] = None, + device: Optional[str] = None, + ) -> None: + """Initializes the CharacterModel.""" + + super().__init__( + network_fn, + dataset, + network_args, + dataset_args, + metrics, + criterion, + criterion_args, + optimizer, + optimizer_args, + lr_scheduler, + lr_scheduler_args, + swa_args, + device, + ) + self.pad_token = dataset_args["args"]["pad_token"] + if self._mapper is None: + self._mapper = EmnistMapper(pad_token=self.pad_token,) + self.tensor_transform = ToTensor() + self.softmax = nn.Softmax(dim=0) + + @torch.no_grad() + def predict_on_image( + self, image: Union[np.ndarray, torch.Tensor] + ) -> Tuple[str, float]: + """Character prediction on an image. + + Args: + image (Union[np.ndarray, torch.Tensor]): An image containing a character. + + Returns: + Tuple[str, float]: The predicted character and the confidence in the prediction. 
+ + """ + self.eval() + + if image.dtype == np.uint8: + # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. + image = self.tensor_transform(image) + if image.dtype == torch.uint8: + # If the image is an unscaled tensor. + image = image.type("torch.FloatTensor") / 255 + + # Put the image tensor on the device the model weights are on. + image = image.to(self.device) + logits = self.forward(image) + + prediction = self.softmax(logits.squeeze(0)) + + index = int(torch.argmax(prediction, dim=0)) + confidence_of_prediction = prediction[index] + predicted_character = self.mapper(index) + + return predicted_character, confidence_of_prediction diff --git a/text_recognizer/models/crnn_model.py b/text_recognizer/models/crnn_model.py new file mode 100644 index 0000000..1e01a83 --- /dev/null +++ b/text_recognizer/models/crnn_model.py @@ -0,0 +1,119 @@ +"""Defines the CRNNModel class.""" +from typing import Callable, Dict, Optional, Tuple, Type, Union + +import numpy as np +import torch +from torch import nn +from torch import Tensor +from torch.utils.data import Dataset +from torchvision.transforms import ToTensor + +from text_recognizer.datasets import EmnistMapper +from text_recognizer.models.base import Model +from text_recognizer.networks import greedy_decoder + + +class CRNNModel(Model): + """Model for predicting a sequence of characters from an image of a text line.""" + + def __init__( + self, + network_fn: Type[nn.Module], + dataset: Type[Dataset], + network_args: Optional[Dict] = None, + dataset_args: Optional[Dict] = None, + metrics: Optional[Dict] = None, + criterion: Optional[Callable] = None, + criterion_args: Optional[Dict] = None, + optimizer: Optional[Callable] = None, + optimizer_args: Optional[Dict] = None, + lr_scheduler: Optional[Callable] = None, + lr_scheduler_args: Optional[Dict] = None, + swa_args: Optional[Dict] = None, + device: Optional[str] = None, + ) -> None: + super().__init__( + network_fn, + dataset, + network_args, + dataset_args, + metrics, + criterion, + criterion_args, + optimizer, + optimizer_args, + lr_scheduler, + lr_scheduler_args, + swa_args, + device, + ) + + self.pad_token = dataset_args["args"]["pad_token"] + if self._mapper is None: + self._mapper = EmnistMapper(pad_token=self.pad_token,) + self.tensor_transform = ToTensor() + + def criterion(self, output: Tensor, targets: Tensor) -> Tensor: + """Computes the CTC loss. + + Args: + output (Tensor): Model predictions. + targets (Tensor): Correct output sequence. + + Returns: + Tensor: The CTC loss. + + """ + + # Input lengths on the form [T, B] + input_lengths = torch.full( + size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long, + ) + + # Configure target tensors for ctc loss. + targets_ = Tensor([]).to(self.device) + target_lengths = [] + for t in targets: + # Remove padding symbol as it acts as the blank symbol. + t = t[t < 79] + targets_ = torch.cat([targets_, t]) + target_lengths.append(len(t)) + + targets = targets_.type(dtype=torch.long) + target_lengths = ( + torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device) + ) + + return self._criterion(output, targets, input_lengths, target_lengths) + + @torch.no_grad() + def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: + """Predict on a single input.""" + self.eval() + + if image.dtype == np.uint8: + # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. + image = self.tensor_transform(image) + + # Rescale image between 0 and 1. 
+ if image.dtype == torch.uint8: + # If the image is an unscaled tensor. + image = image.type("torch.FloatTensor") / 255 + + # Put the image tensor on the device the model weights are on. + image = image.to(self.device) + log_probs = self.forward(image) + + raw_pred, _ = greedy_decoder( + predictions=log_probs, + character_mapper=self.mapper, + blank_label=79, + collapse_repeated=True, + ) + + log_probs, _ = log_probs.max(dim=2) + + predicted_characters = "".join(raw_pred[0]) + confidence_of_prediction = log_probs.cumprod(dim=0)[-1].item() + + return predicted_characters, confidence_of_prediction diff --git a/text_recognizer/models/ctc_transformer_model.py b/text_recognizer/models/ctc_transformer_model.py new file mode 100644 index 0000000..25925f2 --- /dev/null +++ b/text_recognizer/models/ctc_transformer_model.py @@ -0,0 +1,120 @@ +"""Defines the CTC Transformer Model class.""" +from typing import Callable, Dict, Optional, Tuple, Type, Union + +import numpy as np +import torch +from torch import nn +from torch import Tensor +from torch.utils.data import Dataset +from torchvision.transforms import ToTensor + +from text_recognizer.datasets import EmnistMapper +from text_recognizer.models.base import Model +from text_recognizer.networks import greedy_decoder + + +class CTCTransformerModel(Model): + """Model for predicting a sequence of characters from an image of a text line.""" + + def __init__( + self, + network_fn: Type[nn.Module], + dataset: Type[Dataset], + network_args: Optional[Dict] = None, + dataset_args: Optional[Dict] = None, + metrics: Optional[Dict] = None, + criterion: Optional[Callable] = None, + criterion_args: Optional[Dict] = None, + optimizer: Optional[Callable] = None, + optimizer_args: Optional[Dict] = None, + lr_scheduler: Optional[Callable] = None, + lr_scheduler_args: Optional[Dict] = None, + swa_args: Optional[Dict] = None, + device: Optional[str] = None, + ) -> None: + super().__init__( + network_fn, + dataset, + network_args, + dataset_args, + metrics, + criterion, + criterion_args, + optimizer, + optimizer_args, + lr_scheduler, + lr_scheduler_args, + swa_args, + device, + ) + self.pad_token = dataset_args["args"]["pad_token"] + self.lower = dataset_args["args"]["lower"] + + if self._mapper is None: + self._mapper = EmnistMapper(pad_token=self.pad_token, lower=self.lower,) + + self.tensor_transform = ToTensor() + + def criterion(self, output: Tensor, targets: Tensor) -> Tensor: + """Computes the CTC loss. + + Args: + output (Tensor): Model predictions. + targets (Tensor): Correct output sequence. + + Returns: + Tensor: The CTC loss. + + """ + # Input lengths on the form [T, B] + input_lengths = torch.full( + size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long, + ) + + # Configure target tensors for ctc loss. + targets_ = Tensor([]).to(self.device) + target_lengths = [] + for t in targets: + # Remove padding symbol as it acts as the blank symbol. + t = t[t < 53] + targets_ = torch.cat([targets_, t]) + target_lengths.append(len(t)) + + targets = targets_.type(dtype=torch.long) + target_lengths = ( + torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device) + ) + + return self._criterion(output, targets, input_lengths, target_lengths) + + @torch.no_grad() + def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: + """Predict on a single input.""" + self.eval() + + if image.dtype == np.uint8: + # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. 
+            image = self.tensor_transform(image)
+
+        # Rescale image between 0 and 1.
+        if image.dtype == torch.uint8:
+            # If the image is an unscaled tensor.
+            image = image.type("torch.FloatTensor") / 255
+
+        # Put the image tensor on the device the model weights are on.
+        image = image.to(self.device)
+        log_probs = self.forward(image)
+
+        raw_pred, _ = greedy_decoder(
+            predictions=log_probs,
+            character_mapper=self.mapper,
+            blank_label=53,
+            collapse_repeated=True,
+        )
+
+        log_probs, _ = log_probs.max(dim=2)
+
+        predicted_characters = "".join(raw_pred[0])
+        confidence_of_prediction = log_probs.cumprod(dim=0)[-1].item()
+
+        return predicted_characters, confidence_of_prediction
diff --git a/text_recognizer/models/segmentation_model.py b/text_recognizer/models/segmentation_model.py
new file mode 100644
index 0000000..613108a
--- /dev/null
+++ b/text_recognizer/models/segmentation_model.py
@@ -0,0 +1,75 @@
+"""Segmentation model for detecting and segmenting lines."""
+from typing import Callable, Dict, Optional, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.models.base import Model
+
+
+class SegmentationModel(Model):
+    """Model for segmenting lines in an image."""
+
+    def __init__(
+        self,
+        network_fn: str,
+        dataset: str,
+        network_args: Optional[Dict] = None,
+        dataset_args: Optional[Dict] = None,
+        metrics: Optional[Dict] = None,
+        criterion: Optional[Callable] = None,
+        criterion_args: Optional[Dict] = None,
+        optimizer: Optional[Callable] = None,
+        optimizer_args: Optional[Dict] = None,
+        lr_scheduler: Optional[Callable] = None,
+        lr_scheduler_args: Optional[Dict] = None,
+        swa_args: Optional[Dict] = None,
+        device: Optional[str] = None,
+    ) -> None:
+        super().__init__(
+            network_fn,
+            dataset,
+            network_args,
+            dataset_args,
+            metrics,
+            criterion,
+            criterion_args,
+            optimizer,
+            optimizer_args,
+            lr_scheduler,
+            lr_scheduler_args,
+            swa_args,
+            device,
+        )
+        self.tensor_transform = ToTensor()
+        self.softmax = nn.Softmax(dim=2)
+
+    @torch.no_grad()
+    def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tensor:
+        """Predict on a single input."""
+        self.eval()
+
+        if image.dtype == np.uint8:
+            # Converts an image with range [0, 255] to a PyTorch Tensor with range [0, 1].
+            image = self.tensor_transform(image)
+
+        # Rescale image between 0 and 1.
+        if image.dtype == torch.uint8 or image.dtype == torch.int64:
+            # If the image is an unscaled tensor.
+            image = image.type("torch.FloatTensor") / 255
+
+        if not torch.is_tensor(image):
+            image = Tensor(image)
+
+        # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device) + + logits = self.forward(image) + + segmentation_mask = torch.argmax(logits, dim=1) + + return segmentation_mask diff --git a/text_recognizer/models/transformer_model.py b/text_recognizer/models/transformer_model.py new file mode 100644 index 0000000..3f63053 --- /dev/null +++ b/text_recognizer/models/transformer_model.py @@ -0,0 +1,124 @@ +"""Defines the CNN-Transformer class.""" +from typing import Callable, Dict, List, Optional, Tuple, Type, Union + +import numpy as np +import torch +from torch import nn +from torch import Tensor +from torch.utils.data import Dataset + +from text_recognizer.datasets import EmnistMapper +import text_recognizer.datasets.transforms as transforms +from text_recognizer.models.base import Model +from text_recognizer.networks import greedy_decoder + + +class TransformerModel(Model): + """Model for predicting a sequence of characters from an image of a text line with a cnn-transformer.""" + + def __init__( + self, + network_fn: str, + dataset: str, + network_args: Optional[Dict] = None, + dataset_args: Optional[Dict] = None, + metrics: Optional[Dict] = None, + criterion: Optional[Callable] = None, + criterion_args: Optional[Dict] = None, + optimizer: Optional[Callable] = None, + optimizer_args: Optional[Dict] = None, + lr_scheduler: Optional[Callable] = None, + lr_scheduler_args: Optional[Dict] = None, + swa_args: Optional[Dict] = None, + device: Optional[str] = None, + ) -> None: + super().__init__( + network_fn, + dataset, + network_args, + dataset_args, + metrics, + criterion, + criterion_args, + optimizer, + optimizer_args, + lr_scheduler, + lr_scheduler_args, + swa_args, + device, + ) + self.init_token = dataset_args["args"]["init_token"] + self.pad_token = dataset_args["args"]["pad_token"] + self.eos_token = dataset_args["args"]["eos_token"] + self.lower = dataset_args["args"]["lower"] + self.max_len = 100 + + if self._mapper is None: + self._mapper = EmnistMapper( + init_token=self.init_token, + pad_token=self.pad_token, + eos_token=self.eos_token, + lower=self.lower, + ) + self.tensor_transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize(mean=[0.912], std=[0.168])] + ) + self.softmax = nn.Softmax(dim=2) + + @torch.no_grad() + def _generate_sentence(self, image: Tensor) -> Tuple[List, float]: + src = self.network.extract_image_features(image) + + # Added for vqvae transformer. + if isinstance(src, Tuple): + src = src[0] + + memory = self.network.encoder(src) + + confidence_of_predictions = [] + trg_indices = [self.mapper(self.init_token)] + + for _ in range(self.max_len - 1): + trg = torch.tensor(trg_indices, device=self.device)[None, :].long() + trg = self.network.target_embedding(trg) + logits = self.network.decoder(trg=trg, memory=memory, trg_mask=None) + + # Convert logits to probabilities. + probs = self.softmax(logits) + + pred_token = probs.argmax(2)[:, -1].item() + confidence = probs.max(2).values[:, -1].item() + + trg_indices.append(pred_token) + confidence_of_predictions.append(confidence) + + if pred_token == self.mapper(self.eos_token): + break + + confidence = np.min(confidence_of_predictions) + predicted_characters = "".join([self.mapper(x) for x in trg_indices[1:]]) + + return predicted_characters, confidence + + @torch.no_grad() + def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: + """Predict on a single input.""" + self.eval() + + if image.dtype == np.uint8: + # Converts an image with range [0, 255] with to PyTorch Tensor with range [0, 1]. 
+            image = self.tensor_transform(image)
+
+        # Rescale image between 0 and 1.
+        if image.dtype == torch.uint8:
+            # If the image is an unscaled tensor.
+            image = image.type("torch.FloatTensor") / 255
+
+        # Put the image tensor on the device the model weights are on.
+        image = image.to(self.device)
+
+        (predicted_characters, confidence_of_prediction,) = self._generate_sentence(
+            image
+        )
+
+        return predicted_characters, confidence_of_prediction
diff --git a/text_recognizer/models/vqvae_model.py b/text_recognizer/models/vqvae_model.py
new file mode 100644
index 0000000..70f6f1f
--- /dev/null
+++ b/text_recognizer/models/vqvae_model.py
@@ -0,0 +1,80 @@
+"""Defines the VQVAEModel class."""
+from typing import Callable, Dict, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets import EmnistMapper
+from text_recognizer.models.base import Model
+
+
+class VQVAEModel(Model):
+    """Model for reconstructing images from codebook."""
+
+    def __init__(
+        self,
+        network_fn: Type[nn.Module],
+        dataset: Type[Dataset],
+        network_args: Optional[Dict] = None,
+        dataset_args: Optional[Dict] = None,
+        metrics: Optional[Dict] = None,
+        criterion: Optional[Callable] = None,
+        criterion_args: Optional[Dict] = None,
+        optimizer: Optional[Callable] = None,
+        optimizer_args: Optional[Dict] = None,
+        lr_scheduler: Optional[Callable] = None,
+        lr_scheduler_args: Optional[Dict] = None,
+        swa_args: Optional[Dict] = None,
+        device: Optional[str] = None,
+    ) -> None:
+        """Initializes the VQVAEModel."""
+
+        super().__init__(
+            network_fn,
+            dataset,
+            network_args,
+            dataset_args,
+            metrics,
+            criterion,
+            criterion_args,
+            optimizer,
+            optimizer_args,
+            lr_scheduler,
+            lr_scheduler_args,
+            swa_args,
+            device,
+        )
+        self.pad_token = dataset_args["args"]["pad_token"]
+        if self._mapper is None:
+            self._mapper = EmnistMapper(pad_token=self.pad_token,)
+        self.tensor_transform = ToTensor()
+        self.softmax = nn.Softmax(dim=0)
+
+    @torch.no_grad()
+    def predict_on_image(self, image: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+        """Reconstruction of an image.
+
+        Args:
+            image (Union[np.ndarray, torch.Tensor]): An image containing a character.
+
+        Returns:
+            torch.Tensor: The reconstructed image.
+
+        """
+        self.eval()
+
+        if image.dtype == np.uint8:
+            # Converts an image with range [0, 255] to a PyTorch Tensor with range [0, 1].
+            image = self.tensor_transform(image)
+        if image.dtype == torch.uint8:
+            # If the image is an unscaled tensor.
+            image = image.type("torch.FloatTensor") / 255
+
+        # Put the image tensor on the device the model weights are on.
+ image = image.to(self.device) + image_reconstructed, _ = self.forward(image) + + return image_reconstructed diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py new file mode 100644 index 0000000..1521355 --- /dev/null +++ b/text_recognizer/networks/__init__.py @@ -0,0 +1,43 @@ +"""Network modules.""" +from .cnn import CNN +from .cnn_transformer import CNNTransformer +from .crnn import ConvolutionalRecurrentNetwork +from .ctc import greedy_decoder +from .densenet import DenseNet +from .lenet import LeNet +from .metrics import accuracy, cer, wer +from .mlp import MLP +from .residual_network import ResidualNetwork, ResidualNetworkEncoder +from .transducer import load_transducer_loss, TDS2d +from .transformer import Transformer +from .unet import UNet +from .util import sliding_window +from .vit import ViT +from .vq_transformer import VQTransformer +from .vqvae import VQVAE +from .wide_resnet import WideResidualNetwork + +__all__ = [ + "accuracy", + "cer", + "CNN", + "CNNTransformer", + "ConvolutionalRecurrentNetwork", + "DenseNet", + "FCN", + "greedy_decoder", + "MLP", + "LeNet", + "load_transducer_loss", + "ResidualNetwork", + "ResidualNetworkEncoder", + "sliding_window", + "UNet", + "TDS2d", + "Transformer", + "ViT", + "VQTransformer", + "VQVAE", + "wer", + "WideResidualNetwork", +] diff --git a/text_recognizer/networks/beam.py b/text_recognizer/networks/beam.py new file mode 100644 index 0000000..dccccdb --- /dev/null +++ b/text_recognizer/networks/beam.py @@ -0,0 +1,83 @@ +"""Implementation of beam search decoder for a sequence to sequence network. + +Stolen from: https://github.com/budzianowski/PyTorch-Beam-Search-Decoding/blob/master/decode_beam.py + +""" +# from typing import List +# from Queue import PriorityQueue + +# from loguru import logger +# import torch +# from torch import nn +# from torch import Tensor +# import torch.nn.functional as F + + +# class Node: +# def __init__( +# self, parent: Node, target_index: int, log_prob: Tensor, length: int +# ) -> None: +# self.parent = parent +# self.target_index = target_index +# self.log_prob = log_prob +# self.length = length +# self.reward = 0.0 + +# def eval(self, alpha: float = 1.0) -> Tensor: +# return self.log_prob / (self.length - 1 + 1e-6) + alpha * self.reward + + +# @torch.no_grad() +# def beam_decoder( +# network, mapper, device, memory: Tensor = None, max_len: int = 97, +# ) -> Tensor: +# beam_width = 10 +# topk = 1 # How many sentences to generate. + +# trg_indices = [mapper(mapper.init_token)] + +# end_nodes = [] + +# node = Node(None, trg_indices, 0, 1) +# nodes = PriorityQueue() + +# nodes.put((node.eval(), node)) +# q_size = 1 + +# # Beam search +# for _ in range(max_len): +# if q_size > 2000: +# logger.warning("Could not decoder input") +# break + +# # Fetch the best node. +# score, n = nodes.get() +# decoder_input = n.target_index + +# if n.target_index == mapper(mapper.eos_token) and n.parent is not None: +# end_nodes.append((score, n)) + +# # If we reached the maximum number of sentences required. +# if len(end_nodes) >= 1: +# break +# else: +# continue + +# # Forward pass with transformer. 
+#         trg = torch.tensor(trg_indices, device=device)[None, :].long()
+#         trg = network.target_embedding(trg)
+#         logits = network.decoder(trg=trg, memory=memory, trg_mask=None)
+#         log_prob = F.log_softmax(logits, dim=2)
+
+#         log_prob, indices = torch.topk(log_prob, beam_width)
+
+#         for new_k in range(beam_width):
+#             # TODO: continue from here
+#             token_index = indices[0][new_k].view(1, -1)
+#             log_p = log_prob[0][new_k].item()
+
+#             node = Node()
+
+#             pass
+
+#         pass
diff --git a/text_recognizer/networks/cnn.py b/text_recognizer/networks/cnn.py
new file mode 100644
index 0000000..1807bb9
--- /dev/null
+++ b/text_recognizer/networks/cnn.py
@@ -0,0 +1,101 @@
+"""Implementation of a simple backbone cnn network."""
+from typing import Callable, Dict, Optional, Tuple
+
+from einops.layers.torch import Rearrange
+import torch
+from torch import nn
+
+from text_recognizer.networks.util import activation_function
+
+
+class CNN(nn.Module):
+    """Simple backbone CNN for feature extraction."""
+
+    def __init__(
+        self,
+        channels: Tuple[int, ...] = (1, 32, 64, 128),
+        kernel_sizes: Tuple[int, ...] = (4, 4, 4),
+        strides: Tuple[int, ...] = (2, 2, 2),
+        max_pool_kernel: int = 2,
+        dropout_rate: float = 0.2,
+        activation: Optional[str] = "relu",
+    ) -> None:
+        """Initialization of the CNN network.
+
+        Args:
+            channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64, 128).
+            kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (4, 4, 4).
+            strides (Tuple[int, ...]): Stride length of the convolutional filter. Defaults to (2, 2, 2).
+            max_pool_kernel (int): 2D max pooling kernel. Defaults to 2.
+            dropout_rate (float): The dropout rate. Defaults to 0.2.
+            activation (Optional[str]): The name of non-linear activation function. Defaults to relu.
+
+        Raises:
+            RuntimeError: if the number of hyperparameters does not match in length.
+
+        """
+        super().__init__()
+
+        if len(channels) - 1 != len(kernel_sizes) or len(kernel_sizes) != len(strides):
+            raise RuntimeError("The number of the hyperparameters does not match.")
+
+        self.cnn = self._build_network(
+            channels, kernel_sizes, strides, max_pool_kernel, dropout_rate, activation,
+        )
+
+    def _build_network(
+        self,
+        channels: Tuple[int, ...],
+        kernel_sizes: Tuple[int, ...],
+        strides: Tuple[int, ...],
+        max_pool_kernel: int,
+        dropout_rate: float,
+        activation: str,
+    ) -> nn.Sequential:
+        # Load activation function.
+        activation_fn = activation_function(activation)
+
+        channels = list(channels)
+        in_channels = channels.pop(0)
+        configuration = zip(channels, kernel_sizes, strides)
+
+        modules = nn.ModuleList([])
+
+        for i, (out_channels, kernel_size, stride) in enumerate(configuration):
+            # Add max pool to reduce output size.
+            if i == len(channels) // 2:
+                modules.append(nn.MaxPool2d(max_pool_kernel))
+            if i == 0:
+                modules.append(
+                    nn.Conv2d(
+                        in_channels, out_channels, kernel_size, stride=stride, padding=1
+                    )
+                )
+            else:
+                modules.append(
+                    nn.Sequential(
+                        activation_fn,
+                        nn.BatchNorm2d(in_channels),
+                        nn.Conv2d(
+                            in_channels,
+                            out_channels,
+                            kernel_size,
+                            stride=stride,
+                            padding=1,
+                        ),
+                    )
+                )
+
+            if dropout_rate:
+                modules.append(nn.Dropout2d(p=dropout_rate))
+
+            in_channels = out_channels
+
+        return nn.Sequential(*modules)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """The feedforward pass."""
+        # If batch dimension is missing, it needs to be added.
+ if len(x.shape) < 4: + x = x[(None,) * (4 - len(x.shape))] + return self.cnn(x) diff --git a/text_recognizer/networks/cnn_transformer.py b/text_recognizer/networks/cnn_transformer.py new file mode 100644 index 0000000..9150b55 --- /dev/null +++ b/text_recognizer/networks/cnn_transformer.py @@ -0,0 +1,158 @@ +"""A CNN-Transformer for image to text recognition.""" +from typing import Dict, Optional, Tuple + +from einops import rearrange +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.transformer import PositionalEncoding, Transformer +from text_recognizer.networks.util import activation_function +from text_recognizer.networks.util import configure_backbone + + +class CNNTransformer(nn.Module): + """CNN+Transfomer for image to sequence prediction.""" + + def __init__( + self, + num_encoder_layers: int, + num_decoder_layers: int, + hidden_dim: int, + vocab_size: int, + num_heads: int, + adaptive_pool_dim: Tuple, + expansion_dim: int, + dropout_rate: float, + trg_pad_index: int, + max_len: int, + backbone: str, + backbone_args: Optional[Dict] = None, + activation: str = "gelu", + pool_kernel: Optional[Tuple[int, int]] = None, + ) -> None: + super().__init__() + self.trg_pad_index = trg_pad_index + self.vocab_size = vocab_size + self.backbone = configure_backbone(backbone, backbone_args) + + if pool_kernel is not None: + self.max_pool = nn.MaxPool2d(pool_kernel, stride=2) + else: + self.max_pool = None + + self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim) + + self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim)) + self.pos_dropout = nn.Dropout(p=dropout_rate) + self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate) + + nn.init.normal_(self.character_embedding.weight, std=0.02) + + self.adaptive_pool = ( + nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None + ) + + self.transformer = Transformer( + num_encoder_layers, + num_decoder_layers, + hidden_dim, + num_heads, + expansion_dim, + dropout_rate, + activation, + ) + + self.head = nn.Sequential( + # nn.Linear(hidden_dim, hidden_dim * 2), + # activation_function(activation), + nn.Linear(hidden_dim, vocab_size), + ) + + def _create_trg_mask(self, trg: Tensor) -> Tensor: + # Move this outside the transformer. + trg_pad_mask = (trg != self.trg_pad_index)[:, None, None] + trg_len = trg.shape[1] + trg_sub_mask = torch.tril( + torch.ones((trg_len, trg_len), device=trg.device) + ).bool() + trg_mask = trg_pad_mask & trg_sub_mask + return trg_mask + + def encoder(self, src: Tensor) -> Tensor: + """Forward pass with the encoder of the transformer.""" + return self.transformer.encoder(src) + + def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor: + """Forward pass with the decoder of the transformer + classification head.""" + return self.head( + self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask) + ) + + def extract_image_features(self, src: Tensor) -> Tensor: + """Extracts image features with a backbone neural network. + + It seem like the winning idea was to swap channels and width dimension and collapse + the height dimension. The transformer is learning like a baby with this implementation!!! :D + Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D + + Args: + src (Tensor): Input tensor. + + Returns: + Tensor: A input src to the transformer. + + """ + # If batch dimension is missing, it needs to be added. 
+ if len(src.shape) < 4: + src = src[(None,) * (4 - len(src.shape))] + + src = self.backbone(src) + + if self.max_pool is not None: + src = self.max_pool(src) + + if self.adaptive_pool is not None and len(src.shape) == 4: + src = rearrange(src, "b c h w -> b w c h") + src = self.adaptive_pool(src) + src = src.squeeze(3) + elif len(src.shape) == 4: + src = rearrange(src, "b c h w -> b (h w) c") + + b, t, _ = src.shape + + src += self.src_position_embedding[:, :t] + src = self.pos_dropout(src) + + return src + + def target_embedding(self, trg: Tensor) -> Tuple[Tensor, Tensor]: + """Encodes target tensor with embedding and postion. + + Args: + trg (Tensor): Target tensor. + + Returns: + Tuple[Tensor, Tensor]: Encoded target tensor and target mask. + + """ + trg = self.character_embedding(trg.long()) + trg = self.trg_position_encoding(trg) + return trg + + def decode_image_features( + self, image_features: Tensor, trg: Optional[Tensor] = None + ) -> Tensor: + """Takes images features from the backbone and decodes them with the transformer.""" + trg_mask = self._create_trg_mask(trg) + trg = self.target_embedding(trg) + out = self.transformer(image_features, trg, trg_mask=trg_mask) + + logits = self.head(out) + return logits + + def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor: + """Forward pass with CNN transfomer.""" + image_features = self.extract_image_features(x) + logits = self.decode_image_features(image_features, trg) + return logits diff --git a/text_recognizer/networks/crnn.py b/text_recognizer/networks/crnn.py new file mode 100644 index 0000000..778e232 --- /dev/null +++ b/text_recognizer/networks/crnn.py @@ -0,0 +1,110 @@ +"""CRNN for handwritten text recognition.""" +from typing import Dict, Tuple + +from einops import rearrange, reduce +from einops.layers.torch import Rearrange +from loguru import logger +from torch import nn +from torch import Tensor + +from text_recognizer.networks.util import configure_backbone + + +class ConvolutionalRecurrentNetwork(nn.Module): + """Network that takes a image of a text line and predicts tokens that are in the image.""" + + def __init__( + self, + backbone: str, + backbone_args: Dict = None, + input_size: int = 128, + hidden_size: int = 128, + bidirectional: bool = False, + num_layers: int = 1, + num_classes: int = 80, + patch_size: Tuple[int, int] = (28, 28), + stride: Tuple[int, int] = (1, 14), + recurrent_cell: str = "lstm", + avg_pool: bool = False, + use_sliding_window: bool = True, + ) -> None: + super().__init__() + self.backbone_args = backbone_args or {} + self.patch_size = patch_size + self.stride = stride + self.sliding_window = ( + self._configure_sliding_window() if use_sliding_window else None + ) + self.input_size = input_size + self.hidden_size = hidden_size + self.backbone = configure_backbone(backbone, backbone_args) + self.bidirectional = bidirectional + self.avg_pool = avg_pool + + if recurrent_cell.upper() in ["LSTM", "GRU"]: + recurrent_cell = getattr(nn, recurrent_cell) + else: + logger.warning( + f"Option {recurrent_cell} not valid, defaulting to LSTM cell." 
+            )
+            recurrent_cell = nn.LSTM
+
+        self.rnn = recurrent_cell(
+            input_size=self.input_size,
+            hidden_size=self.hidden_size,
+            bidirectional=bidirectional,
+            num_layers=num_layers,
+        )
+
+        decoder_size = self.hidden_size * 2 if self.bidirectional else self.hidden_size
+
+        self.decoder = nn.Sequential(
+            nn.Linear(in_features=decoder_size, out_features=num_classes),
+            nn.LogSoftmax(dim=2),
+        )
+
+        # No adaptive pooling layer is configured for the non-sliding-window path yet.
+        self.adaptive_pool = None
+
+    def _configure_sliding_window(self) -> nn.Sequential:
+        return nn.Sequential(
+            nn.Unfold(kernel_size=self.patch_size, stride=self.stride),
+            Rearrange(
+                "b (c h w) t -> b t c h w",
+                h=self.patch_size[0],
+                w=self.patch_size[1],
+                c=1,
+            ),
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM."""
+        if len(x.shape) < 4:
+            x = x[(None,) * (4 - len(x.shape))]
+
+        if self.sliding_window is not None:
+            # Create image patches with a sliding window kernel.
+            x = self.sliding_window(x)
+
+            # Rearrange from a sequence of patches for feedforward network.
+            b, t = x.shape[:2]
+            x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)
+
+            x = self.backbone(x)
+
+            # Average pooling.
+            if self.avg_pool:
+                x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)
+            else:
+                x = rearrange(x, "(b t) h -> t b h", b=b, t=t)
+        else:
+            # Encode the entire image with a CNN, and use the channels as temporal dimension.
+            x = self.backbone(x)
+            x = rearrange(x, "b c h w -> b w c h")
+            if self.adaptive_pool is not None:
+                x = self.adaptive_pool(x)
+            x = x.squeeze(3)
+
+        # Sequence predictions.
+        x, _ = self.rnn(x)
+
+        # Sequence to classification layer.
+        x = self.decoder(x)
+        return x
diff --git a/text_recognizer/networks/ctc.py b/text_recognizer/networks/ctc.py
new file mode 100644
index 0000000..af9b700
--- /dev/null
+++ b/text_recognizer/networks/ctc.py
@@ -0,0 +1,58 @@
+"""Decodes the CTC output."""
+from typing import Callable, List, Optional, Tuple
+
+from einops import rearrange
+import torch
+from torch import Tensor
+
+from text_recognizer.datasets.util import EmnistMapper
+
+
+def greedy_decoder(
+    predictions: Tensor,
+    targets: Optional[Tensor] = None,
+    target_lengths: Optional[Tensor] = None,
+    character_mapper: Optional[Callable] = None,
+    blank_label: int = 79,
+    collapse_repeated: bool = True,
+) -> Tuple[List[str], List[str]]:
+    """Greedy CTC decoder.
+
+    Args:
+        predictions (Tensor): Tensor of network predictions, shape [time, batch, classes].
+        targets (Optional[Tensor]): Target tensor, shape is [batch, targets]. Defaults to None.
+        target_lengths (Optional[Tensor]): Length of each target tensor. Defaults to None.
+        character_mapper (Optional[Callable]): An emnist/character mapper for mapping integers to characters. Defaults
+            to None.
+        blank_label (int): The blank character to be ignored. Defaults to 79.
+        collapse_repeated (bool): Collapse consecutive predictions of the same character. Defaults to True.
+
+    Returns:
+        Tuple[List[str], List[str]]: Tuple of decoded predictions and decoded targets.
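
    Example (shapes are illustrative):
        log_probs = network(images)  # [T, B, num_classes]
        predictions, _ = greedy_decoder(log_probs, blank_label=79)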
+ + """ + + if character_mapper is None: + character_mapper = EmnistMapper(pad_token="_") # noqa: S106 + + predictions = rearrange(torch.argmax(predictions, dim=2), "t b -> b t") + decoded_predictions = [] + decoded_targets = [] + for i, prediction in enumerate(predictions): + decoded_prediction = [] + decoded_target = [] + if targets is not None and target_lengths is not None: + for target_index in targets[i][: target_lengths[i]]: + if target_index == blank_label: + continue + decoded_target.append(character_mapper(int(target_index))) + decoded_targets.append(decoded_target) + for j, index in enumerate(prediction): + if index != blank_label: + if collapse_repeated and j != 0 and index == prediction[j - 1]: + continue + decoded_prediction.append(index.item()) + decoded_predictions.append( + [character_mapper(int(pred_index)) for pred_index in decoded_prediction] + ) + return decoded_predictions, decoded_targets diff --git a/text_recognizer/networks/densenet.py b/text_recognizer/networks/densenet.py new file mode 100644 index 0000000..7dc58d9 --- /dev/null +++ b/text_recognizer/networks/densenet.py @@ -0,0 +1,225 @@ +"""Defines a Densely Connected Convolutional Networks in PyTorch. + +Sources: +https://arxiv.org/abs/1608.06993 +https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py + +""" +from typing import List, Optional, Union + +from einops.layers.torch import Rearrange +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.util import activation_function + + +class _DenseLayer(nn.Module): + """A dense layer with pre-batch norm -> activation function -> Conv-layer x 2.""" + + def __init__( + self, + in_channels: int, + growth_rate: int, + bn_size: int, + dropout_rate: float, + activation: str = "relu", + ) -> None: + super().__init__() + activation_fn = activation_function(activation) + self.dense_layer = [ + nn.BatchNorm2d(in_channels), + activation_fn, + nn.Conv2d( + in_channels=in_channels, + out_channels=bn_size * growth_rate, + kernel_size=1, + stride=1, + bias=False, + ), + nn.BatchNorm2d(bn_size * growth_rate), + activation_fn, + nn.Conv2d( + in_channels=bn_size * growth_rate, + out_channels=growth_rate, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + ] + if dropout_rate: + self.dense_layer.append(nn.Dropout(p=dropout_rate)) + + self.dense_layer = nn.Sequential(*self.dense_layer) + + def forward(self, x: Union[Tensor, List[Tensor]]) -> Tensor: + if isinstance(x, list): + x = torch.cat(x, 1) + return self.dense_layer(x) + + +class _DenseBlock(nn.Module): + def __init__( + self, + num_layers: int, + in_channels: int, + bn_size: int, + growth_rate: int, + dropout_rate: float, + activation: str = "relu", + ) -> None: + super().__init__() + self.dense_block = self._build_dense_blocks( + num_layers, in_channels, bn_size, growth_rate, dropout_rate, activation, + ) + + def _build_dense_blocks( + self, + num_layers: int, + in_channels: int, + bn_size: int, + growth_rate: int, + dropout_rate: float, + activation: str = "relu", + ) -> nn.ModuleList: + dense_block = [] + for i in range(num_layers): + dense_block.append( + _DenseLayer( + in_channels=in_channels + i * growth_rate, + growth_rate=growth_rate, + bn_size=bn_size, + dropout_rate=dropout_rate, + activation=activation, + ) + ) + return nn.ModuleList(dense_block) + + def forward(self, x: Tensor) -> Tensor: + feature_maps = [x] + for layer in self.dense_block: + x = layer(feature_maps) + feature_maps.append(x) + return torch.cat(feature_maps, 1) + + +class 
_Transition(nn.Module): + def __init__( + self, in_channels: int, out_channels: int, activation: str = "relu", + ) -> None: + super().__init__() + activation_fn = activation_function(activation) + self.transition = nn.Sequential( + nn.BatchNorm2d(in_channels), + activation_fn, + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + bias=False, + ), + nn.AvgPool2d(kernel_size=2, stride=2), + ) + + def forward(self, x: Tensor) -> Tensor: + return self.transition(x) + + +class DenseNet(nn.Module): + """Implementation of Densenet, a network archtecture that concats previous layers for maximum infomation flow.""" + + def __init__( + self, + growth_rate: int = 32, + block_config: List[int] = (6, 12, 24, 16), + in_channels: int = 1, + base_channels: int = 64, + num_classes: int = 80, + bn_size: int = 4, + dropout_rate: float = 0, + classifier: bool = True, + activation: str = "relu", + ) -> None: + super().__init__() + self.densenet = self._configure_densenet( + in_channels, + base_channels, + num_classes, + growth_rate, + block_config, + bn_size, + dropout_rate, + classifier, + activation, + ) + + def _configure_densenet( + self, + in_channels: int, + base_channels: int, + num_classes: int, + growth_rate: int, + block_config: List[int], + bn_size: int, + dropout_rate: float, + classifier: bool, + activation: str, + ) -> nn.Sequential: + activation_fn = activation_function(activation) + densenet = [ + nn.Conv2d( + in_channels=in_channels, + out_channels=base_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + nn.BatchNorm2d(base_channels), + activation_fn, + ] + + num_features = base_channels + + for i, num_layers in enumerate(block_config): + densenet.append( + _DenseBlock( + num_layers=num_layers, + in_channels=num_features, + bn_size=bn_size, + growth_rate=growth_rate, + dropout_rate=dropout_rate, + activation=activation, + ) + ) + num_features = num_features + num_layers * growth_rate + if i != len(block_config) - 1: + densenet.append( + _Transition( + in_channels=num_features, + out_channels=num_features // 2, + activation=activation, + ) + ) + num_features = num_features // 2 + + densenet.append(activation_fn) + + if classifier: + densenet.append(nn.AdaptiveAvgPool2d((1, 1))) + densenet.append(Rearrange("b c h w -> b (c h w)")) + densenet.append( + nn.Linear(in_features=num_features, out_features=num_classes) + ) + + return nn.Sequential(*densenet) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass of Densenet.""" + # If batch dimenstion is missing, it will be added. + if len(x.shape) < 4: + x = x[(None,) * (4 - len(x.shape))] + return self.densenet(x) diff --git a/text_recognizer/networks/lenet.py b/text_recognizer/networks/lenet.py new file mode 100644 index 0000000..527e1a0 --- /dev/null +++ b/text_recognizer/networks/lenet.py @@ -0,0 +1,68 @@ +"""Implementation of the LeNet network.""" +from typing import Callable, Dict, Optional, Tuple + +from einops.layers.torch import Rearrange +import torch +from torch import nn + +from text_recognizer.networks.util import activation_function + + +class LeNet(nn.Module): + """LeNet network for character prediction.""" + + def __init__( + self, + channels: Tuple[int, ...] = (1, 32, 64), + kernel_sizes: Tuple[int, ...] = (3, 3, 2), + hidden_size: Tuple[int, ...] = (9216, 128), + dropout_rate: float = 0.2, + num_classes: int = 10, + activation_fn: Optional[str] = "relu", + ) -> None: + """Initialization of the LeNet network. 
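+
+ The defaults assume 28x28 grayscale inputs (e.g. EMNIST): two unpadded 3x3
+ convolutions followed by a 2x2 max pool yield 64 x 12 x 12 = 9216 flattened features.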
+ + Args: + channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64). + kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2). + hidden_size (Tuple[int, ...]): Size of the flattend output form the convolutional layers. + Defaults to (9216, 128). + dropout_rate (float): The dropout rate. Defaults to 0.2. + num_classes (int): Number of classes. Defaults to 10. + activation_fn (Optional[str]): The name of non-linear activation function. Defaults to relu. + + """ + super().__init__() + + activation_fn = activation_function(activation_fn) + + self.layers = [ + nn.Conv2d( + in_channels=channels[0], + out_channels=channels[1], + kernel_size=kernel_sizes[0], + ), + activation_fn, + nn.Conv2d( + in_channels=channels[1], + out_channels=channels[2], + kernel_size=kernel_sizes[1], + ), + activation_fn, + nn.MaxPool2d(kernel_sizes[2]), + nn.Dropout(p=dropout_rate), + Rearrange("b c h w -> b (c h w)"), + nn.Linear(in_features=hidden_size[0], out_features=hidden_size[1]), + activation_fn, + nn.Dropout(p=dropout_rate), + nn.Linear(in_features=hidden_size[1], out_features=num_classes), + ] + + self.layers = nn.Sequential(*self.layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """The feedforward pass.""" + # If batch dimenstion is missing, it needs to be added. + if len(x.shape) < 4: + x = x[(None,) * (4 - len(x.shape))] + return self.layers(x) diff --git a/text_recognizer/networks/loss/__init__.py b/text_recognizer/networks/loss/__init__.py new file mode 100644 index 0000000..b489264 --- /dev/null +++ b/text_recognizer/networks/loss/__init__.py @@ -0,0 +1,2 @@ +"""Loss module.""" +from .loss import EmbeddingLoss, LabelSmoothingCrossEntropy diff --git a/text_recognizer/networks/loss/loss.py b/text_recognizer/networks/loss/loss.py new file mode 100644 index 0000000..cf9fa0d --- /dev/null +++ b/text_recognizer/networks/loss/loss.py @@ -0,0 +1,69 @@ +"""Implementations of custom loss functions.""" +from pytorch_metric_learning import distances, losses, miners, reducers +import torch +from torch import nn +from torch import Tensor +from torch.autograd import Variable +import torch.nn.functional as F + +__all__ = ["EmbeddingLoss", "LabelSmoothingCrossEntropy"] + + +class EmbeddingLoss: + """Metric loss for training encoders to produce information-rich latent embeddings.""" + + def __init__(self, margin: float = 0.2, type_of_triplets: str = "semihard") -> None: + self.distance = distances.CosineSimilarity() + self.reducer = reducers.ThresholdReducer(low=0) + self.loss_fn = losses.TripletMarginLoss( + margin=margin, distance=self.distance, reducer=self.reducer + ) + self.miner = miners.MultiSimilarityMiner(epsilon=margin, distance=self.distance) + + def __call__(self, embeddings: Tensor, labels: Tensor) -> Tensor: + """Computes the metric loss for the embeddings based on their labels. + + Args: + embeddings (Tensor): The laten vectors encoded by the network. + labels (Tensor): Labels of the embeddings. + + Returns: + Tensor: The metric loss for the embeddings. 
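+
+ Example:
+ A minimal sketch with hypothetical shapes; 128-dim embeddings with integer class labels:
+ >>> loss_fn = EmbeddingLoss()
+ >>> loss = loss_fn(torch.randn(8, 128), torch.randint(0, 10, (8,)))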
+ + """ + hard_pairs = self.miner(embeddings, labels) + loss = self.loss_fn(embeddings, labels, hard_pairs) + return loss + + +class LabelSmoothingCrossEntropy(nn.Module): + """Label smoothing loss function.""" + + def __init__( + self, + classes: int, + smoothing: float = 0.0, + ignore_index: int = None, + dim: int = -1, + ) -> None: + super().__init__() + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + self.ignore_index = ignore_index + self.cls = classes + self.dim = dim + + def forward(self, pred: Tensor, target: Tensor) -> Tensor: + """Calculates the loss.""" + pred = pred.log_softmax(dim=self.dim) + with torch.no_grad(): + # true_dist = pred.data.clone() + true_dist = torch.zeros_like(pred) + true_dist.fill_(self.smoothing / (self.cls - 1)) + true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) + if self.ignore_index is not None: + true_dist[:, self.ignore_index] = 0 + mask = torch.nonzero(target == self.ignore_index, as_tuple=False) + if mask.dim() > 0: + true_dist.index_fill_(0, mask.squeeze(), 0.0) + return torch.mean(torch.sum(-true_dist * pred, dim=self.dim)) diff --git a/text_recognizer/networks/metrics.py b/text_recognizer/networks/metrics.py new file mode 100644 index 0000000..2605731 --- /dev/null +++ b/text_recognizer/networks/metrics.py @@ -0,0 +1,123 @@ +"""Utility functions for models.""" +from typing import Optional + +from einops import rearrange +import Levenshtein as Lev +import torch +from torch import Tensor + +from text_recognizer.networks import greedy_decoder + + +def accuracy(outputs: Tensor, labels: Tensor, pad_index: int = 53) -> float: + """Computes the accuracy. + + Args: + outputs (Tensor): The output from the network. + labels (Tensor): Ground truth labels. + pad_index (int): Padding index. + + Returns: + float: The accuracy for the batch. + + """ + + _, predicted = torch.max(outputs, dim=-1) + + # Mask out the pad tokens + mask = labels != pad_index + + predicted *= mask + labels *= mask + + acc = (predicted == labels).sum().float() / labels.shape[0] + acc = acc.item() + return acc + + +def cer( + outputs: Tensor, + targets: Tensor, + batch_size: Optional[int] = None, + blank_label: Optional[int] = int, +) -> float: + """Computes the character error rate. + + Args: + outputs (Tensor): The output from the network. + targets (Tensor): Ground truth labels. + batch_size (Optional[int]): Batch size if target and output has been flattend. + blank_label (Optional[int]): The blank character to be ignored. Defaults to 79. + + Returns: + float: The cer for the batch. + + """ + if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None: + targets = rearrange(targets, "(b t) -> b t", b=batch_size) + outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size) + + target_lengths = torch.full( + size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long, + ) + decoded_predictions, decoded_targets = greedy_decoder( + outputs, targets, target_lengths, blank_label=blank_label, + ) + + lev_dist = 0 + + for prediction, target in zip(decoded_predictions, decoded_targets): + prediction = "".join(prediction) + target = "".join(target) + prediction, target = ( + prediction.replace(" ", ""), + target.replace(" ", ""), + ) + lev_dist += Lev.distance(prediction, target) + return lev_dist / len(decoded_predictions) + + +def wer( + outputs: Tensor, + targets: Tensor, + batch_size: Optional[int] = None, + blank_label: Optional[int] = int, +) -> float: + """Computes the Word error rate. 
+ + Args: + outputs (Tensor): The output from the network. + targets (Tensor): Ground truth labels. + batch_size (optional[int]): Batch size if target and output has been flattend. + blank_label (Optional[int]): The blank character to be ignored. Defaults to 79. + + Returns: + float: The wer for the batch. + + """ + if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None: + targets = rearrange(targets, "(b t) -> b t", b=batch_size) + outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size) + + target_lengths = torch.full( + size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long, + ) + decoded_predictions, decoded_targets = greedy_decoder( + outputs, targets, target_lengths, blank_label=blank_label, + ) + + lev_dist = 0 + + for prediction, target in zip(decoded_predictions, decoded_targets): + prediction = "".join(prediction) + target = "".join(target) + + b = set(prediction.split() + target.split()) + word2char = dict(zip(b, range(len(b)))) + + w1 = [chr(word2char[w]) for w in prediction.split()] + w2 = [chr(word2char[w]) for w in target.split()] + + lev_dist += Lev.distance("".join(w1), "".join(w2)) + + return lev_dist / len(decoded_predictions) diff --git a/text_recognizer/networks/mlp.py b/text_recognizer/networks/mlp.py new file mode 100644 index 0000000..1101912 --- /dev/null +++ b/text_recognizer/networks/mlp.py @@ -0,0 +1,73 @@ +"""Defines the MLP network.""" +from typing import Callable, Dict, List, Optional, Union + +from einops.layers.torch import Rearrange +import torch +from torch import nn + +from text_recognizer.networks.util import activation_function + + +class MLP(nn.Module): + """Multi layered perceptron network.""" + + def __init__( + self, + input_size: int = 784, + num_classes: int = 10, + hidden_size: Union[int, List] = 128, + num_layers: int = 3, + dropout_rate: float = 0.2, + activation_fn: str = "relu", + ) -> None: + """Initialization of the MLP network. + + Args: + input_size (int): The input shape of the network. Defaults to 784. + num_classes (int): Number of classes in the dataset. Defaults to 10. + hidden_size (Union[int, List]): The number of `neurons` in each hidden layer. Defaults to 128. + num_layers (int): The number of hidden layers. Defaults to 3. + dropout_rate (float): The dropout rate at each layer. Defaults to 0.2. + activation_fn (str): Name of the activation function in the hidden layers. Defaults to + relu. + + """ + super().__init__() + + activation_fn = activation_function(activation_fn) + + if isinstance(hidden_size, int): + hidden_size = [hidden_size] * num_layers + + self.layers = [ + Rearrange("b c h w -> b (c h w)"), + nn.Linear(in_features=input_size, out_features=hidden_size[0]), + activation_fn, + ] + + for i in range(num_layers - 1): + self.layers += [ + nn.Linear(in_features=hidden_size[i], out_features=hidden_size[i + 1]), + activation_fn, + ] + + if dropout_rate: + self.layers.append(nn.Dropout(p=dropout_rate)) + + self.layers.append( + nn.Linear(in_features=hidden_size[-1], out_features=num_classes) + ) + + self.layers = nn.Sequential(*self.layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """The feedforward pass.""" + # If batch dimenstion is missing, it needs to be added. 
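+ # Indexing with a tuple of Nones prepends the missing axes, e.g. (H, W) -> (1, 1, H, W).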
+ if len(x.shape) < 4: + x = x[(None,) * (4 - len(x.shape))] + return self.layers(x) + + @property + def __name__(self) -> str: + """Returns the name of the network.""" + return "mlp" diff --git a/text_recognizer/networks/residual_network.py b/text_recognizer/networks/residual_network.py new file mode 100644 index 0000000..c33f419 --- /dev/null +++ b/text_recognizer/networks/residual_network.py @@ -0,0 +1,310 @@ +"""Residual CNN.""" +from functools import partial +from typing import Callable, Dict, List, Optional, Type, Union + +from einops.layers.torch import Rearrange, Reduce +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.util import activation_function + + +class Conv2dAuto(nn.Conv2d): + """Convolution with auto padding based on kernel size.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2) + + +def conv_bn(in_channels: int, out_channels: int, *args, **kwargs) -> nn.Sequential: + """3x3 convolution with batch norm.""" + conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,) + return nn.Sequential( + conv3x3(in_channels, out_channels, *args, **kwargs), + nn.BatchNorm2d(out_channels), + ) + + +class IdentityBlock(nn.Module): + """Residual with identity block.""" + + def __init__( + self, in_channels: int, out_channels: int, activation: str = "relu" + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.blocks = nn.Identity() + self.activation_fn = activation_function(activation) + self.shortcut = nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + residual = x + if self.apply_shortcut: + residual = self.shortcut(x) + x = self.blocks(x) + x += residual + x = self.activation_fn(x) + return x + + @property + def apply_shortcut(self) -> bool: + """Check if shortcut should be applied.""" + return self.in_channels != self.out_channels + + +class ResidualBlock(IdentityBlock): + """Residual with nonlinear shortcut.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + expansion: int = 1, + downsampling: int = 1, + *args, + **kwargs + ) -> None: + """Short summary. + + Args: + in_channels (int): Number of in channels. + out_channels (int): umber of out channels. + expansion (int): Expansion factor of the out channels. Defaults to 1. + downsampling (int): Downsampling factor used in stride. Defaults to 1. + *args (type): Extra arguments. + **kwargs (type): Extra key value arguments. 
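+
+ The shortcut becomes a strided 1x1 convolution whenever the input and
+ expanded output channels differ.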
+ + """ + super().__init__(in_channels, out_channels, *args, **kwargs) + self.expansion = expansion + self.downsampling = downsampling + + self.shortcut = ( + nn.Sequential( + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.expanded_channels, + kernel_size=1, + stride=self.downsampling, + bias=False, + ), + nn.BatchNorm2d(self.expanded_channels), + ) + if self.apply_shortcut + else None + ) + + @property + def expanded_channels(self) -> int: + """Computes the expanded output channels.""" + return self.out_channels * self.expansion + + @property + def apply_shortcut(self) -> bool: + """Check if shortcut should be applied.""" + return self.in_channels != self.expanded_channels + + +class BasicBlock(ResidualBlock): + """Basic ResNet block.""" + + expansion = 1 + + def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None: + super().__init__(in_channels, out_channels, *args, **kwargs) + self.blocks = nn.Sequential( + conv_bn( + in_channels=self.in_channels, + out_channels=self.out_channels, + bias=False, + stride=self.downsampling, + ), + self.activation_fn, + conv_bn( + in_channels=self.out_channels, + out_channels=self.expanded_channels, + bias=False, + ), + ) + + +class BottleNeckBlock(ResidualBlock): + """Bottleneck block to increase depth while minimizing parameter size.""" + + expansion = 4 + + def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None: + super().__init__(in_channels, out_channels, *args, **kwargs) + self.blocks = nn.Sequential( + conv_bn( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=1, + ), + self.activation_fn, + conv_bn( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=self.downsampling, + ), + self.activation_fn, + conv_bn( + in_channels=self.out_channels, + out_channels=self.expanded_channels, + kernel_size=1, + ), + ) + + +class ResidualLayer(nn.Module): + """ResNet layer.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + block: BasicBlock = BasicBlock, + num_blocks: int = 1, + *args, + **kwargs + ) -> None: + super().__init__() + downsampling = 2 if in_channels != out_channels else 1 + self.blocks = nn.Sequential( + block( + in_channels, out_channels, *args, **kwargs, downsampling=downsampling + ), + *[ + block( + out_channels * block.expansion, + out_channels, + downsampling=1, + *args, + **kwargs + ) + for _ in range(num_blocks - 1) + ] + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + x = self.blocks(x) + return x + + +class ResidualNetworkEncoder(nn.Module): + """Encoder network.""" + + def __init__( + self, + in_channels: int = 1, + block_sizes: Union[int, List[int]] = (32, 64), + depths: Union[int, List[int]] = (2, 2), + activation: str = "relu", + block: Type[nn.Module] = BasicBlock, + levels: int = 1, + *args, + **kwargs + ) -> None: + super().__init__() + self.block_sizes = ( + block_sizes if isinstance(block_sizes, list) else [block_sizes] * levels + ) + self.depths = depths if isinstance(depths, list) else [depths] * levels + self.activation = activation + self.gate = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, + out_channels=self.block_sizes[0], + kernel_size=7, + stride=2, + padding=1, + bias=False, + ), + nn.BatchNorm2d(self.block_sizes[0]), + activation_function(self.activation), + # nn.MaxPool2d(kernel_size=2, stride=2, padding=1), + ) + + self.blocks = self._configure_blocks(block) + + def _configure_blocks( + self, block: Type[nn.Module], *args, **kwargs + 
) -> nn.Sequential: + channels = [self.block_sizes[0]] + list( + zip(self.block_sizes, self.block_sizes[1:]) + ) + blocks = [ + ResidualLayer( + in_channels=channels[0], + out_channels=channels[0], + num_blocks=self.depths[0], + block=block, + activation=self.activation, + *args, + **kwargs + ) + ] + blocks += [ + ResidualLayer( + in_channels=in_channels * block.expansion, + out_channels=out_channels, + num_blocks=num_blocks, + block=block, + activation=self.activation, + *args, + **kwargs + ) + for (in_channels, out_channels), num_blocks in zip( + channels[1:], self.depths[1:] + ) + ] + + return nn.Sequential(*blocks) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + # If batch dimenstion is missing, it needs to be added. + if len(x.shape) == 3: + x = x.unsqueeze(0) + x = self.gate(x) + x = self.blocks(x) + return x + + +class ResidualNetworkDecoder(nn.Module): + """Classification head.""" + + def __init__(self, in_features: int, num_classes: int = 80) -> None: + super().__init__() + self.decoder = nn.Sequential( + Reduce("b c h w -> b c", "mean"), + nn.Linear(in_features=in_features, out_features=num_classes), + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + return self.decoder(x) + + +class ResidualNetwork(nn.Module): + """Full residual network.""" + + def __init__(self, in_channels: int, num_classes: int, *args, **kwargs) -> None: + super().__init__() + self.encoder = ResidualNetworkEncoder(in_channels, *args, **kwargs) + self.decoder = ResidualNetworkDecoder( + in_features=self.encoder.blocks[-1].blocks[-1].expanded_channels, + num_classes=num_classes, + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + x = self.encoder(x) + x = self.decoder(x) + return x diff --git a/text_recognizer/networks/stn.py b/text_recognizer/networks/stn.py new file mode 100644 index 0000000..e9d216f --- /dev/null +++ b/text_recognizer/networks/stn.py @@ -0,0 +1,44 @@ +"""Spatial Transformer Network.""" + +from einops.layers.torch import Rearrange +import torch +from torch import nn +from torch import Tensor +import torch.nn.functional as F + + +class SpatialTransformerNetwork(nn.Module): + """A network with differentiable attention. + + Network that learns how to perform spatial transformations on the input image in order to enhance the + geometric invariance of the model. + + # TODO: add arguments to make it more general. + + """ + + def __init__(self) -> None: + super().__init__() + # Initialize the identity transformation and its weights and biases. 
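+ # The bias [1, 0, 0, 0, 1, 0] is the flattened 2x3 identity affine matrix,
+ # so the network starts as a no-op transform.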
+ linear = nn.Linear(32, 3 * 2) + linear.weight.data.zero_() + linear.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)) + + self.theta = nn.Sequential( + nn.Conv2d(in_channels=1, out_channels=8, kernel_size=7), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.ReLU(inplace=True), + nn.Conv2d(in_channels=8, out_channels=10, kernel_size=5), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.ReLU(inplace=True), + Rearrange("b c h w -> b (c h w)", h=3, w=3), + nn.Linear(in_features=10 * 3 * 3, out_features=32), + nn.ReLU(inplace=True), + linear, + Rearrange("b (row col) -> b row col", row=2, col=3), + ) + + def forward(self, x: Tensor) -> Tensor: + """The spatial transformation.""" + grid = F.affine_grid(self.theta(x), x.shape) + return F.grid_sample(x, grid, align_corners=False) diff --git a/text_recognizer/networks/transducer/__init__.py b/text_recognizer/networks/transducer/__init__.py new file mode 100644 index 0000000..8c19a01 --- /dev/null +++ b/text_recognizer/networks/transducer/__init__.py @@ -0,0 +1,3 @@ +"""Transducer modules.""" +from .tds_conv import TDS2d +from .transducer import load_transducer_loss, Transducer diff --git a/text_recognizer/networks/transducer/tds_conv.py b/text_recognizer/networks/transducer/tds_conv.py new file mode 100644 index 0000000..5fb8ba9 --- /dev/null +++ b/text_recognizer/networks/transducer/tds_conv.py @@ -0,0 +1,208 @@ +"""Time-Depth Separable Convolutions. + +References: + https://arxiv.org/abs/1904.02619 + https://arxiv.org/pdf/2010.01003.pdf + +Code stolen from: + https://github.com/facebookresearch/gtn_applications + + +""" +from typing import List, Tuple + +from einops import rearrange +import gtn +import numpy as np +import torch +from torch import nn +from torch import Tensor + + +class TDSBlock2d(nn.Module): + """Internal block of a 2D TDSC network.""" + + def __init__( + self, + in_channels: int, + img_depth: int, + kernel_size: Tuple[int], + dropout_rate: float, + ) -> None: + super().__init__() + + self.in_channels = in_channels + self.img_depth = img_depth + self.kernel_size = kernel_size + self.dropout_rate = dropout_rate + self.fc_dim = in_channels * img_depth + + # Network placeholders. + self.conv = None + self.mlp = None + self.instance_norm = None + + self._build_block() + + def _build_block(self) -> None: + # Convolutional block. + self.conv = nn.Sequential( + nn.Conv3d( + in_channels=self.in_channels, + out_channels=self.in_channels, + kernel_size=(1, self.kernel_size[0], self.kernel_size[1]), + padding=(0, self.kernel_size[0] // 2, self.kernel_size[1] // 2), + ), + nn.ReLU(inplace=True), + nn.Dropout(self.dropout_rate), + ) + + # MLP block. + self.mlp = nn.Sequential( + nn.Linear(self.fc_dim, self.fc_dim), + nn.ReLU(inplace=True), + nn.Dropout(self.dropout_rate), + nn.Linear(self.fc_dim, self.fc_dim), + nn.Dropout(self.dropout_rate), + ) + + # Instance norm. + self.instance_norm = nn.ModuleList( + [ + nn.InstanceNorm2d(self.fc_dim, affine=True), + nn.InstanceNorm2d(self.fc_dim, affine=True), + ] + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass. + + Args: + x (Tensor): Input tensor. + + Shape: + - x: :math: `(B, CD, H, W)` + + Returns: + Tensor: Output tensor. 
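+
+ The output shape equals the input shape: both the convolution and the MLP
+ are residual sub-blocks followed by instance norm.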
+ + """ + B, CD, H, W = x.shape + C, D = self.in_channels, self.img_depth + residual = x + x = rearrange(x, "b (c d) h w -> b c d h w", c=C, d=D) + x = self.conv(x) + x = rearrange(x, "b c d h w -> b (c d) h w") + x += residual + + x = self.instance_norm[0](x) + + x = self.mlp(x.transpose(1, 3)).transpose(1, 3) + x + x + self.instance_norm[1](x) + + # Output shape: [B, CD, H, W] + return x + + +class TDS2d(nn.Module): + """TDS Netowrk. + + Structure is the following: + Downsample layer -> TDS2d group -> ... -> Linear output layer + + + """ + + def __init__( + self, + input_dim: int, + output_dim: int, + depth: int, + tds_groups: Tuple[int], + kernel_size: Tuple[int], + dropout_rate: float, + in_channels: int = 1, + ) -> None: + super().__init__() + + self.in_channels = in_channels + self.input_dim = input_dim + self.output_dim = output_dim + self.depth = depth + self.tds_groups = tds_groups + self.kernel_size = kernel_size + self.dropout_rate = dropout_rate + + self.tds = None + self.fc = None + + self._build_network() + + def _build_network(self) -> None: + in_channels = self.in_channels + modules = [] + stride_h = np.prod([grp["stride"][0] for grp in self.tds_groups]) + if self.input_dim % stride_h: + raise RuntimeError( + f"Image height not divisible by total stride {stride_h}." + ) + + for tds_group in self.tds_groups: + # Add downsample layer. + out_channels = self.depth * tds_group["channels"] + modules.extend( + [ + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=self.kernel_size, + padding=(self.kernel_size[0] // 2, self.kernel_size[1] // 2), + stride=tds_group["stride"], + ), + nn.ReLU(inplace=True), + nn.Dropout(self.dropout_rate), + nn.InstanceNorm2d(out_channels, affine=True), + ] + ) + + for _ in range(tds_group["num_blocks"]): + modules.append( + TDSBlock2d( + tds_group["channels"], + self.depth, + self.kernel_size, + self.dropout_rate, + ) + ) + + in_channels = out_channels + + self.tds = nn.Sequential(*modules) + self.fc = nn.Linear(in_channels * self.input_dim // stride_h, self.output_dim) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass. + + Args: + x (Tensor): Input tensor. + + Shape: + - x: :math: `(B, H, W)` + + Returns: + Tensor: Output tensor. + + """ + if len(x.shape) == 4: + x = x.squeeze(1) # Squeeze the channel dim away. 
+ + B, H, W = x.shape + x = rearrange( + x, "b (h1 h2) w -> b h1 h2 w", h1=self.in_channels, h2=H // self.in_channels + ) + x = self.tds(x) + + # x shape: [B, C, H, W] + x = rearrange(x, "b c h w -> b w (c h)") + + return self.fc(x) diff --git a/text_recognizer/networks/transducer/test.py b/text_recognizer/networks/transducer/test.py new file mode 100644 index 0000000..cadcecc --- /dev/null +++ b/text_recognizer/networks/transducer/test.py @@ -0,0 +1,60 @@ +import torch +from torch import nn + +from text_recognizer.networks.transducer import load_transducer_loss, Transducer +import unittest + + +class TestTransducer(unittest.TestCase): + def test_viterbi(self): + T = 5 + N = 4 + B = 2 + + # fmt: off + emissions1 = torch.tensor(( + 0, 4, 0, 1, + 0, 2, 1, 1, + 0, 0, 0, 2, + 0, 0, 0, 2, + 8, 0, 0, 2, + ), + dtype=torch.float, + ).view(T, N) + emissions2 = torch.tensor(( + 0, 2, 1, 7, + 0, 2, 9, 1, + 0, 0, 0, 2, + 0, 0, 5, 2, + 1, 0, 0, 2, + ), + dtype=torch.float, + ).view(T, N) + # fmt: on + + # Test without blank: + labels = [[1, 3, 0], [3, 2, 3, 2, 3]] + transducer = Transducer( + tokens=["a", "b", "c", "d"], + graphemes_to_idx={"a": 0, "b": 1, "c": 2, "d": 3}, + blank="none", + ) + emissions = torch.stack([emissions1, emissions2], dim=0) + predictions = transducer.viterbi(emissions) + self.assertEqual([p.tolist() for p in predictions], labels) + + # Test with blank without repeats: + labels = [[1, 0], [2, 2]] + transducer = Transducer( + tokens=["a", "b", "c"], + graphemes_to_idx={"a": 0, "b": 1, "c": 2}, + blank="optional", + allow_repeats=False, + ) + emissions = torch.stack([emissions1, emissions2], dim=0) + predictions = transducer.viterbi(emissions) + self.assertEqual([p.tolist() for p in predictions], labels) + + +if __name__ == "__main__": + unittest.main() diff --git a/text_recognizer/networks/transducer/transducer.py b/text_recognizer/networks/transducer/transducer.py new file mode 100644 index 0000000..d7e3d08 --- /dev/null +++ b/text_recognizer/networks/transducer/transducer.py @@ -0,0 +1,410 @@ +"""Transducer and the transducer loss function.py + +Stolen from: + https://github.com/facebookresearch/gtn_applications/blob/master/transducer.py + +""" +from pathlib import Path +import itertools +from typing import Dict, List, Optional, Union, Tuple + +from loguru import logger +import gtn +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.datasets.iam_preprocessor import Preprocessor + + +def make_scalar_graph(weight) -> gtn.Graph: + scalar = gtn.Graph() + scalar.add_node(True) + scalar.add_node(False, True) + scalar.add_arc(0, 1, 0, 0, weight) + return scalar + + +def make_chain_graph(sequence) -> gtn.Graph: + graph = gtn.Graph(False) + graph.add_node(True) + for i, s in enumerate(sequence): + graph.add_node(False, i == (len(sequence) - 1)) + graph.add_arc(i, i + 1, s) + return graph + + +def make_transitions_graph( + ngram: int, num_tokens: int, calc_grad: bool = False +) -> gtn.Graph: + transitions = gtn.Graph(calc_grad) + transitions.add_node(True, ngram == 1) + + state_map = {(): 0} + + # First build transitions which include <s>: + for n in range(1, ngram): + for state in itertools.product(range(num_tokens), repeat=n): + in_idx = state_map[state[:-1]] + out_idx = transitions.add_node(False, ngram == 1) + state_map[state] = out_idx + transitions.add_arc(in_idx, out_idx, state[-1]) + + for state in itertools.product(range(num_tokens), repeat=ngram): + state_idx = state_map[state[:-1]] + new_state_idx = state_map[state[1:]] + # p(state[-1] | 
state[:-1]) + transitions.add_arc(state_idx, new_state_idx, state[-1]) + + if ngram > 1: + # Build transitions which include </s>: + end_idx = transitions.add_node(False, True) + for in_idx in range(end_idx): + transitions.add_arc(in_idx, end_idx, gtn.epsilon) + + return transitions + + +def make_lexicon_graph(word_pieces: List, graphemes_to_idx: Dict) -> gtn.Graph: + """Constructs a graph which transduces letters to word pieces.""" + graph = gtn.Graph(False) + graph.add_node(True, True) + for i, wp in enumerate(word_pieces): + prev = 0 + for l in wp[:-1]: + n = graph.add_node() + graph.add_arc(prev, n, graphemes_to_idx[l], gtn.epsilon) + prev = n + graph.add_arc(prev, 0, graphemes_to_idx[wp[-1]], i) + graph.arc_sort() + return graph + + +def make_token_graph( + token_list: List, blank: str = "none", allow_repeats: bool = True +) -> gtn.Graph: + """Constructs a graph with all the individual token transition models.""" + if not allow_repeats and blank != "optional": + raise ValueError("Must use blank='optional' if disallowing repeats.") + + ntoks = len(token_list) + graph = gtn.Graph(False) + + # Creating nodes + graph.add_node(True, True) + for i in range(ntoks): + # We can consume one or more consecutive word + # pieces for each emission: + # E.g. [ab, ab, ab] transduces to [ab] + graph.add_node(False, blank != "forced") + + if blank != "none": + graph.add_node() + + # Creating arcs + if blank != "none": + # Blank index is assumed to be last (ntoks) + graph.add_arc(0, ntoks + 1, ntoks, gtn.epsilon) + graph.add_arc(ntoks + 1, 0, gtn.epsilon) + + for i in range(ntoks): + graph.add_arc((ntoks + 1) if blank == "forced" else 0, i + 1, i) + graph.add_arc(i + 1, i + 1, i, gtn.epsilon) + + if allow_repeats: + if blank == "forced": + # Allow transitions from token to blank only + graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon) + else: + # Allow transition from token to blank and all other tokens + graph.add_arc(i + 1, 0, gtn.epsilon) + + else: + # allow transitions to blank and all other tokens except the same token + graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon) + for j in range(ntoks): + if i != j: + graph.add_arc(i + 1, j + 1, j, j) + + return graph + + +class TransducerLossFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + inputs, + targets, + tokens, + lexicon, + transition_params=None, + transitions=None, + reduction="none", + ) -> Tensor: + B, T, C = inputs.shape + + losses = [None] * B + emissions_graphs = [None] * B + + if transitions is not None: + if transition_params is None: + raise ValueError("Specified transitions, but not transition params.") + + cpu_data = transition_params.cpu().contiguous() + transitions.set_weights(cpu_data.data_ptr()) + transitions.calc_grad = transition_params.requires_grad + transitions.zero_grad() + + def process(b: int) -> None: + # Create emission graph: + emissions = gtn.linear_graph(T, C, inputs.requires_grad) + cpu_data = inputs[b].cpu().contiguous() + emissions.set_weights(cpu_data.data_ptr()) + target = make_chain_graph(targets[b]) + target.arc_sort(True) + + # Create token tot grapheme decomposition graph + tokens_target = gtn.remove(gtn.project_output(gtn.compose(target, lexicon))) + tokens_target.arc_sort() + + # Create alignment graph: + aligments = gtn.project_input( + gtn.remove(gtn.compose(tokens, tokens_target)) + ) + aligments.arc_sort() + + # Add transitions scores: + if transitions is not None: + aligments = gtn.intersect(transitions, aligments) + aligments.arc_sort() + + loss = 
gtn.forward_score(gtn.intersect(emissions, aligments)) + + # Normalize if needed: + if transitions is not None: + norm = gtn.forward_score(gtn.intersect(emissions, transitions)) + loss = gtn.subtract(loss, norm) + + losses[b] = gtn.negate(loss) + + # Save for backward: + if emissions.calc_grad: + emissions_graphs[b] = emissions + + gtn.parallel_for(process, range(B)) + + ctx.graphs = (losses, emissions_graphs, transitions) + ctx.input_shape = inputs.shape + + # Optionally reduce by target length + if reduction == "mean": + scales = [(1 / len(t) if len(t) > 0 else 1.0) for t in targets] + else: + scales = [1.0] * B + + ctx.scales = scales + + loss = torch.tensor([l.item() * s for l, s in zip(losses, scales)]) + return torch.mean(loss.to(inputs.device)) + + @staticmethod + def backward(ctx, grad_output) -> Tuple: + losses, emissions_graphs, transitions = ctx.graphs + scales = ctx.scales + + B, T, C = ctx.input_shape + calc_emissions = ctx.needs_input_grad[0] + input_grad = torch.empty((B, T, C)) if calc_emissions else None + + def process(b: int) -> None: + scale = make_scalar_graph(scales[b]) + gtn.backward(losses[b], scale) + emissions = emissions_graphs[b] + if calc_emissions: + grad = emissions.grad().weights_to_numpy() + input_grad[b] = torch.tensor(grad).view(1, T, C) + + gtn.parallel_for(process, range(B)) + + if calc_emissions: + input_grad = input_grad.to(grad_output.device) + input_grad *= grad_output / B + + if ctx.needs_input_grad[4]: + grad = transitions.grad().weights_to_numpy() + transition_grad = torch.tensor(grad).to(grad_output.device) + transition_grad *= grad_output / B + else: + transition_grad = None + + return ( + input_grad, + None, # target + None, # tokens + None, # lexicon + transition_grad, # transition params + None, # transitions graph + None, + ) + + +TransducerLoss = TransducerLossFunction.apply + + +class Transducer(nn.Module): + def __init__( + self, + tokens: List, + graphemes_to_idx: Dict, + ngram: int = 0, + transitions: str = None, + blank: str = "none", + allow_repeats: bool = True, + reduction: str = "none", + ) -> None: + """A generic transducer loss function. + + Args: + tokens (List) : A list of iterable objects (e.g. strings, tuples, etc) + representing the output tokens of the model (e.g. letters, + word-pieces, words). For example ["a", "b", "ab", "ba", "aba"] + could be a list of sub-word tokens. + graphemes_to_idx (dict) : A dictionary mapping grapheme units (e.g. + "a", "b", ..) to their corresponding integer index. + ngram (int) : Order of the token-level transition model. If `ngram=0` + then no transition model is used. + blank (string) : Specifies the usage of blank token + 'none' - do not use blank token + 'optional' - allow an optional blank inbetween tokens + 'forced' - force a blank inbetween tokens (also referred to as garbage token) + allow_repeats (boolean) : If false, then we don't allow paths with + consecutive tokens in the alignment graph. This keeps the graph + unambiguous in the sense that the same input cannot transduce to + different outputs. + """ + super().__init__() + if blank not in ["optional", "forced", "none"]: + raise ValueError( + "Invalid value specified for blank. 
Must be in ['optional', 'forced', 'none']"
+ )
+ self.tokens = make_token_graph(tokens, blank=blank, allow_repeats=allow_repeats)
+ self.lexicon = make_lexicon_graph(tokens, graphemes_to_idx)
+ self.ngram = ngram
+ if ngram > 0 and transitions is not None:
+ raise ValueError("Only one of ngram and transitions may be specified")
+
+ if ngram > 0:
+ transitions = make_transitions_graph(
+ ngram, len(tokens) + int(blank != "none"), True
+ )
+
+ if transitions is not None:
+ self.transitions = transitions
+ self.transitions.arc_sort()
+ self.transitions_params = nn.Parameter(
+ torch.zeros(self.transitions.num_arcs())
+ )
+ else:
+ self.transitions = None
+ self.transitions_params = None
+ self.reduction = reduction
+
+ def forward(self, inputs: Tensor, targets: Tensor) -> Tensor:
+ """Computes the transducer loss."""
+ return TransducerLoss(
+ inputs,
+ targets,
+ self.tokens,
+ self.lexicon,
+ self.transitions_params,
+ self.transitions,
+ self.reduction,
+ )
+
+ def viterbi(self, outputs: Tensor) -> List[Tensor]:
+ """Decodes the most likely token sequence for each batch element."""
+ B, T, C = outputs.shape
+
+ if self.transitions is not None:
+ cpu_data = self.transitions_params.cpu().contiguous()
+ self.transitions.set_weights(cpu_data.data_ptr())
+ self.transitions.calc_grad = False
+
+ self.tokens.arc_sort()
+
+ paths = [None] * B
+
+ def process(b: int) -> None:
+ emissions = gtn.linear_graph(T, C, False)
+ cpu_data = outputs[b].cpu().contiguous()
+ emissions.set_weights(cpu_data.data_ptr())
+
+ if self.transitions is not None:
+ full_graph = gtn.intersect(emissions, self.transitions)
+ else:
+ full_graph = emissions
+
+ # Find the best path and remove back-off arcs:
+ path = gtn.remove(gtn.viterbi_path(full_graph))
+
+ # Left compose the viterbi path with the "alignment to token"
+ # transducer to get the outputs:
+ path = gtn.compose(path, self.tokens)
+
+ # When there are ambiguous paths (allow_repeats is true), we take
+ # the shortest:
+ path = gtn.viterbi_path(path)
+ path = gtn.remove(gtn.project_output(path))
+ paths[b] = path.labels_to_list()
+
+ gtn.parallel_for(process, range(B))
+ predictions = [torch.IntTensor(path) for path in paths]
+ return predictions
+
+
+def load_transducer_loss(
+ num_features: int,
+ ngram: int,
+ tokens: str,
+ lexicon: str,
+ transitions: str,
+ blank: str,
+ allow_repeats: bool,
+ prepend_wordsep: bool = False,
+ use_words: bool = False,
+ data_dir: Optional[Union[str, Path]] = None,
+ reduction: str = "mean",
+) -> Tuple[Transducer, int]:
+ """Loads the IAM preprocessor and constructs the transducer loss."""
+ if data_dir is None:
+ data_dir = (
+ Path(__file__).resolve().parents[4] / "data" / "raw" / "iam" / "iamdb"
+ )
+ logger.debug(f"Using data dir: {data_dir}")
+ if not data_dir.exists():
+ raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+ else:
+ data_dir = Path(data_dir)
+ processed_path = (
+ Path(__file__).resolve().parents[4] / "data" / "processed" / "iam_lines"
+ )
+ tokens_path = processed_path / tokens
+ lexicon_path = processed_path / lexicon
+
+ if transitions is not None:
+ transitions = gtn.load(str(processed_path / transitions))
+
+ preprocessor = Preprocessor(
+ data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep,
+ )
+
+ num_tokens = preprocessor.num_tokens
+
+ criterion = Transducer(
+ preprocessor.tokens,
+ preprocessor.graphemes_to_index,
+ ngram=ngram,
+ transitions=transitions,
+ blank=blank,
+ allow_repeats=allow_repeats,
+ reduction=reduction,
+ )
+
+ return criterion, num_tokens + int(blank != "none") diff --git a/text_recognizer/networks/transformer/__init__.py b/text_recognizer/networks/transformer/__init__.py new file mode 100644 index
0000000..9febc88 --- /dev/null +++ b/text_recognizer/networks/transformer/__init__.py @@ -0,0 +1,3 @@ +"""Transformer modules.""" +from .positional_encoding import PositionalEncoding +from .transformer import Decoder, Encoder, EncoderLayer, Transformer diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py new file mode 100644 index 0000000..cce1ecc --- /dev/null +++ b/text_recognizer/networks/transformer/attention.py @@ -0,0 +1,93 @@ +"""Implementes the attention module for the transformer.""" +from typing import Optional, Tuple + +from einops import rearrange +import numpy as np +import torch +from torch import nn +from torch import Tensor + + +class MultiHeadAttention(nn.Module): + """Implementation of multihead attention.""" + + def __init__( + self, hidden_dim: int, num_heads: int = 8, dropout_rate: float = 0.0 + ) -> None: + super().__init__() + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.fc_q = nn.Linear( + in_features=hidden_dim, out_features=hidden_dim, bias=False + ) + self.fc_k = nn.Linear( + in_features=hidden_dim, out_features=hidden_dim, bias=False + ) + self.fc_v = nn.Linear( + in_features=hidden_dim, out_features=hidden_dim, bias=False + ) + self.fc_out = nn.Linear(in_features=hidden_dim, out_features=hidden_dim) + + self._init_weights() + + self.dropout = nn.Dropout(p=dropout_rate) + + def _init_weights(self) -> None: + nn.init.normal_( + self.fc_q.weight, + mean=0, + std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)), + ) + nn.init.normal_( + self.fc_k.weight, + mean=0, + std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)), + ) + nn.init.normal_( + self.fc_v.weight, + mean=0, + std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)), + ) + nn.init.xavier_normal_(self.fc_out.weight) + + def scaled_dot_product_attention( + self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None + ) -> Tensor: + """Calculates the scaled dot product attention.""" + + # Compute the energy. + energy = torch.einsum("bhlk,bhtk->bhlt", [query, key]) / np.sqrt( + query.shape[-1] + ) + + # If we have a mask for padding some inputs. + if mask is not None: + energy = energy.masked_fill(mask == 0, -np.inf) + + # Compute the attention from the energy. + attention = torch.softmax(energy, dim=3) + + out = torch.einsum("bhlt,bhtv->bhlv", [attention, value]) + out = rearrange(out, "b head l v -> b l (head v)") + return out, attention + + def forward( + self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None + ) -> Tuple[Tensor, Tensor]: + """Forward pass for computing the multihead attention.""" + # Get the query, key, and value tensor. 
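+ # Each linear projection is split into num_heads heads of size hidden_dim // num_heads.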
+ query = rearrange( + self.fc_q(query), "b l (head k) -> b head l k", head=self.num_heads + ) + key = rearrange( + self.fc_k(key), "b t (head k) -> b head t k", head=self.num_heads + ) + value = rearrange( + self.fc_v(value), "b t (head v) -> b head t v", head=self.num_heads + ) + + out, attention = self.scaled_dot_product_attention(query, key, value, mask) + + out = self.fc_out(out) + out = self.dropout(out) + return out, attention diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py new file mode 100644 index 0000000..1ba5537 --- /dev/null +++ b/text_recognizer/networks/transformer/positional_encoding.py @@ -0,0 +1,32 @@ +"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence.""" +import numpy as np +import torch +from torch import nn +from torch import Tensor + + +class PositionalEncoding(nn.Module): + """Encodes a sense of distance or time for transformer networks.""" + + def __init__( + self, hidden_dim: int, dropout_rate: float, max_len: int = 1000 + ) -> None: + super().__init__() + self.dropout = nn.Dropout(p=dropout_rate) + self.max_len = max_len + + pe = torch.zeros(max_len, hidden_dim) + position = torch.arange(0, max_len).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim) + ) + + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.register_buffer("pe", pe) + + def forward(self, x: Tensor) -> Tensor: + """Encodes the tensor with a postional embedding.""" + x = x + self.pe[:, : x.shape[1]] + return self.dropout(x) diff --git a/text_recognizer/networks/transformer/transformer.py b/text_recognizer/networks/transformer/transformer.py new file mode 100644 index 0000000..dd180c4 --- /dev/null +++ b/text_recognizer/networks/transformer/transformer.py @@ -0,0 +1,264 @@ +"""Transfomer module.""" +import copy +from typing import Dict, Optional, Type, Union + +import numpy as np +import torch +from torch import nn +from torch import Tensor +import torch.nn.functional as F + +from text_recognizer.networks.transformer.attention import MultiHeadAttention +from text_recognizer.networks.util import activation_function + + +class GEGLU(nn.Module): + """GLU activation for improving feedforward activations.""" + + def __init__(self, dim_in: int, dim_out: int) -> None: + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x: Tensor) -> Tensor: + """Forward propagation.""" + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList: + return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)]) + + +class _IntraLayerConnection(nn.Module): + """Preforms the residual connection inside the transfomer blocks and applies layernorm.""" + + def __init__(self, dropout_rate: float, hidden_dim: int) -> None: + super().__init__() + self.norm = nn.LayerNorm(normalized_shape=hidden_dim) + self.dropout = nn.Dropout(p=dropout_rate) + + def forward(self, src: Tensor, residual: Tensor) -> Tensor: + return self.norm(self.dropout(src) + residual) + + +class _ConvolutionalLayer(nn.Module): + def __init__( + self, + hidden_dim: int, + expansion_dim: int, + dropout_rate: float, + activation: str = "relu", + ) -> None: + super().__init__() + + in_projection = ( + nn.Sequential( + nn.Linear(hidden_dim, expansion_dim), 
activation_function(activation) + ) + if activation != "glu" + else GEGLU(hidden_dim, expansion_dim) + ) + + self.layer = nn.Sequential( + in_projection, + nn.Dropout(p=dropout_rate), + nn.Linear(in_features=expansion_dim, out_features=hidden_dim), + ) + + def forward(self, x: Tensor) -> Tensor: + return self.layer(x) + + +class EncoderLayer(nn.Module): + """Transfomer encoding layer.""" + + def __init__( + self, + hidden_dim: int, + num_heads: int, + expansion_dim: int, + dropout_rate: float, + activation: str = "relu", + ) -> None: + super().__init__() + self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate) + self.cnn = _ConvolutionalLayer( + hidden_dim, expansion_dim, dropout_rate, activation + ) + self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim) + self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim) + + def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor: + """Forward pass through the encoder.""" + # First block. + # Multi head attention. + out, _ = self.self_attention(src, src, src, mask) + + # Add & norm. + out = self.block1(out, src) + + # Second block. + # Apply 1D-convolution. + cnn_out = self.cnn(out) + + # Add & norm. + out = self.block2(cnn_out, out) + + return out + + +class Encoder(nn.Module): + """Transfomer encoder module.""" + + def __init__( + self, + num_layers: int, + encoder_layer: Type[nn.Module], + norm: Optional[Type[nn.Module]] = None, + ) -> None: + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.norm = norm + + def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor: + """Forward pass through all encoder layers.""" + for layer in self.layers: + src = layer(src, src_mask) + + if self.norm is not None: + src = self.norm(src) + + return src + + +class DecoderLayer(nn.Module): + """Transfomer decoder layer.""" + + def __init__( + self, + hidden_dim: int, + num_heads: int, + expansion_dim: int, + dropout_rate: float = 0.0, + activation: str = "relu", + ) -> None: + super().__init__() + self.hidden_dim = hidden_dim + self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate) + self.multihead_attention = MultiHeadAttention( + hidden_dim, num_heads, dropout_rate + ) + self.cnn = _ConvolutionalLayer( + hidden_dim, expansion_dim, dropout_rate, activation + ) + self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim) + self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim) + self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim) + + def forward( + self, + trg: Tensor, + memory: Tensor, + trg_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + ) -> Tensor: + """Forward pass of the layer.""" + out, _ = self.self_attention(trg, trg, trg, trg_mask) + trg = self.block1(out, trg) + + out, _ = self.multihead_attention(trg, memory, memory, memory_mask) + trg = self.block2(out, trg) + + out = self.cnn(trg) + out = self.block3(out, trg) + + return out + + +class Decoder(nn.Module): + """Transfomer decoder module.""" + + def __init__( + self, + decoder_layer: Type[nn.Module], + num_layers: int, + norm: Optional[Type[nn.Module]] = None, + ) -> None: + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + trg: Tensor, + memory: Tensor, + trg_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + ) -> Tensor: + """Forward pass through the decoder.""" + for layer in self.layers: + trg = layer(trg, 
memory, trg_mask, memory_mask) + + if self.norm is not None: + trg = self.norm(trg) + + return trg + + +class Transformer(nn.Module): + """Transformer network.""" + + def __init__( + self, + num_encoder_layers: int, + num_decoder_layers: int, + hidden_dim: int, + num_heads: int, + expansion_dim: int, + dropout_rate: float, + activation: str = "relu", + ) -> None: + super().__init__() + + # Configure encoder. + encoder_norm = nn.LayerNorm(hidden_dim) + encoder_layer = EncoderLayer( + hidden_dim, num_heads, expansion_dim, dropout_rate, activation + ) + self.encoder = Encoder(num_encoder_layers, encoder_layer, encoder_norm) + + # Configure decoder. + decoder_norm = nn.LayerNorm(hidden_dim) + decoder_layer = DecoderLayer( + hidden_dim, num_heads, expansion_dim, dropout_rate, activation + ) + self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm) + + self._reset_parameters() + + def _reset_parameters(self) -> None: + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward( + self, + src: Tensor, + trg: Tensor, + src_mask: Optional[Tensor] = None, + trg_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + ) -> Tensor: + """Forward pass through the transformer.""" + if src.shape[0] != trg.shape[0]: + print(trg.shape) + raise RuntimeError("The batch size of the src and trg must be the same.") + if src.shape[2] != trg.shape[2]: + raise RuntimeError( + "The number of features for the src and trg must be the same." + ) + + memory = self.encoder(src, src_mask) + output = self.decoder(trg, memory, trg_mask, memory_mask) + return output diff --git a/text_recognizer/networks/unet.py b/text_recognizer/networks/unet.py new file mode 100644 index 0000000..510910f --- /dev/null +++ b/text_recognizer/networks/unet.py @@ -0,0 +1,255 @@ +"""UNet for segmentation.""" +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.util import activation_function + + +class _ConvBlock(nn.Module): + """Modified UNet convolutional block with dilation.""" + + def __init__( + self, + channels: List[int], + activation: str, + num_groups: int, + dropout_rate: float = 0.1, + kernel_size: int = 3, + dilation: int = 1, + padding: int = 0, + ) -> None: + super().__init__() + self.channels = channels + self.dropout_rate = dropout_rate + self.kernel_size = kernel_size + self.dilation = dilation + self.padding = padding + self.num_groups = num_groups + self.activation = activation_function(activation) + self.block = self._configure_block() + self.residual_conv = nn.Sequential( + nn.Conv2d( + self.channels[0], self.channels[-1], kernel_size=3, stride=1, padding=1 + ), + self.activation, + ) + + def _configure_block(self) -> nn.Sequential: + block = [] + for i in range(len(self.channels) - 1): + block += [ + nn.Dropout(p=self.dropout_rate), + nn.GroupNorm(self.num_groups, self.channels[i]), + self.activation, + nn.Conv2d( + self.channels[i], + self.channels[i + 1], + kernel_size=self.kernel_size, + padding=self.padding, + stride=1, + dilation=self.dilation, + ), + ] + + return nn.Sequential(*block) + + def forward(self, x: Tensor) -> Tensor: + """Apply the convolutional block.""" + residual = self.residual_conv(x) + return self.block(x) + residual + + +class _DownSamplingBlock(nn.Module): + """Basic down sampling block.""" + + def __init__( + self, + channels: List[int], + activation: str, + num_groups: int, + pooling_kernel: Union[int, bool] = 2, + dropout_rate: float = 0.1, + 
+        kernel_size: int = 3,
+        dilation: int = 1,
+        padding: int = 0,
+    ) -> None:
+        super().__init__()
+        self.conv_block = _ConvBlock(
+            channels,
+            activation,
+            num_groups,
+            dropout_rate,
+            kernel_size,
+            dilation,
+            padding,
+        )
+        self.down_sampling = nn.MaxPool2d(pooling_kernel) if pooling_kernel else None
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+        """Return the convolutional block output and a down sampled tensor."""
+        x = self.conv_block(x)
+        x_down = self.down_sampling(x) if self.down_sampling is not None else x
+
+        return x_down, x
+
+
+class _UpSamplingBlock(nn.Module):
+    """The upsampling block of the UNet."""
+
+    def __init__(
+        self,
+        channels: List[int],
+        activation: str,
+        num_groups: int,
+        scale_factor: int = 2,
+        dropout_rate: float = 0.1,
+        kernel_size: int = 3,
+        dilation: int = 1,
+        padding: int = 0,
+    ) -> None:
+        super().__init__()
+        self.conv_block = _ConvBlock(
+            channels,
+            activation,
+            num_groups,
+            dropout_rate,
+            kernel_size,
+            dilation,
+            padding,
+        )
+        self.up_sampling = nn.Upsample(
+            scale_factor=scale_factor, mode="bilinear", align_corners=True
+        )
+
+    def forward(self, x: Tensor, x_skip: Optional[Tensor] = None) -> Tensor:
+        """Apply the up sampling and convolutional block."""
+        x = self.up_sampling(x)
+        if x_skip is not None:
+            x = torch.cat((x, x_skip), dim=1)
+        return self.conv_block(x)
+
+
+class UNet(nn.Module):
+    """UNet architecture."""
+
+    def __init__(
+        self,
+        in_channels: int = 1,
+        base_channels: int = 64,
+        num_classes: int = 3,
+        depth: int = 4,
+        activation: str = "relu",
+        num_groups: int = 8,
+        dropout_rate: float = 0.1,
+        pooling_kernel: int = 2,
+        scale_factor: int = 2,
+        kernel_size: Optional[List[int]] = None,
+        dilation: Optional[List[int]] = None,
+        padding: Optional[List[int]] = None,
+    ) -> None:
+        super().__init__()
+        self.depth = depth
+        self.num_groups = num_groups
+
+        if kernel_size is not None and dilation is not None and padding is not None:
+            if (
+                len(kernel_size) != depth
+                or len(dilation) != depth
+                or len(padding) != depth
+            ):
+                raise RuntimeError(
+                    "Length of convolutional parameters does not match the depth."
+                )
+            self.kernel_size = kernel_size
+            self.padding = padding
+            self.dilation = dilation
+
+        else:
+            self.kernel_size = [3] * depth
+            self.padding = [1] * depth
+            self.dilation = [1] * depth
+
+        self.dropout_rate = dropout_rate
+        self.conv = nn.Conv2d(
+            in_channels, base_channels, kernel_size=3, stride=1, padding=1
+        )
+
+        channels = [base_channels] + [base_channels * 2 ** i for i in range(depth)]
+        self.encoder_blocks = self._configure_down_sampling_blocks(
+            channels, activation, pooling_kernel
+        )
+        self.decoder_blocks = self._configure_up_sampling_blocks(
+            channels, activation, scale_factor
+        )
+
+        self.head = nn.Conv2d(base_channels, num_classes, kernel_size=1)
+
+    def _configure_down_sampling_blocks(
+        self, channels: List[int], activation: str, pooling_kernel: int
+    ) -> nn.ModuleList:
+        blocks = nn.ModuleList([])
+        for i in range(len(channels) - 1):
+            # Only pool between levels; the last (bottleneck) block does not
+            # down sample.
+            pooling_kernel = pooling_kernel if i < self.depth - 1 else False
+            dropout_rate = self.dropout_rate if i < 0 else 0
+            blocks += [
+                _DownSamplingBlock(
+                    [channels[i], channels[i + 1], channels[i + 1]],
+                    activation,
+                    self.num_groups,
+                    pooling_kernel,
+                    dropout_rate,
+                    self.kernel_size[i],
+                    self.dilation[i],
+                    self.padding[i],
+                )
+            ]
+
+        return blocks
+
+    def _configure_up_sampling_blocks(
+        self, channels: List[int], activation: str, scale_factor: int,
+    ) -> nn.ModuleList:
+        channels.reverse()
+        self.kernel_size.reverse()
+        self.dilation.reverse()
+        self.padding.reverse()
+        return nn.ModuleList(
+            [
+                _UpSamplingBlock(
+                    [channels[i] + channels[i + 1], channels[i + 1], channels[i + 1]],
+                    activation,
+                    self.num_groups,
+                    scale_factor,
+                    self.dropout_rate,
+                    self.kernel_size[i],
+                    self.dilation[i],
+                    self.padding[i],
+                )
+                for i in range(len(channels) - 2)
+            ]
+        )
+
+    def _encode(self, x: Tensor) -> List[Tensor]:
+        x_skips = []
+        for block in self.encoder_blocks:
+            x, x_skip = block(x)
+            x_skips.append(x_skip)
+        return x_skips
+
+    def _decode(self, x_skips: List[Tensor]) -> Tensor:
+        x = x_skips[-1]
+        for i, block in enumerate(self.decoder_blocks):
+            x = block(x, x_skips[-(i + 2)])
+        return x
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward pass with the UNet model."""
+        if len(x.shape) < 4:
+            x = x[(None,) * (4 - len(x.shape))]
+        x = self.conv(x)
+        x_skips = self._encode(x)
+        x = self._decode(x_skips)
+        return self.head(x)
diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py
new file mode 100644
index 0000000..131a6b4
--- /dev/null
+++ b/text_recognizer/networks/util.py
@@ -0,0 +1,89 @@
+"""Miscellaneous neural network functionality."""
+import importlib
+from pathlib import Path
+from typing import Dict, Tuple, Type
+
+from einops import rearrange
+from loguru import logger
+import torch
+from torch import nn
+
+
+def sliding_window(
+    images: torch.Tensor, patch_size: Tuple[int, int], stride: Tuple[int, int]
+) -> torch.Tensor:
+    """Creates patches of an image.
+
+    Args:
+        images (torch.Tensor): A Torch tensor of a 4D image(s), i.e. (batch, channel, height, width).
+        patch_size (Tuple[int, int]): The size of the patches to generate, e.g. 28x28 for EMNIST.
+        stride (Tuple[int, int]): The stride of the sliding window.
+
+    Returns:
+        torch.Tensor: A tensor with the shape (batch, patches, channel, height, width).
+
+    """
+    unfold = nn.Unfold(kernel_size=patch_size, stride=stride)
+    # Perform the sliding window, then restore the channel dimension that unfold collapses.
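+    # Illustrative shape walk-through (an added note, not in the original
+    # source): for images of shape (1, 1, 28, 56) with patch_size=(28, 28)
+    # and stride=(28, 28), nn.Unfold returns (1, 1 * 28 * 28, 2); the
+    # rearrange below then recovers (batch=1, patches=2, c=1, h=28, w=28).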
+ c = images.shape[1] + patches = unfold(images) + patches = rearrange( + patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1], + ) + return patches + + +def activation_function(activation: str) -> Type[nn.Module]: + """Returns the callable activation function.""" + activation_fns = nn.ModuleDict( + [ + ["elu", nn.ELU(inplace=True)], + ["gelu", nn.GELU()], + ["glu", nn.GLU()], + ["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)], + ["none", nn.Identity()], + ["relu", nn.ReLU(inplace=True)], + ["selu", nn.SELU(inplace=True)], + ] + ) + return activation_fns[activation.lower()] + + +def configure_backbone(backbone: str, backbone_args: Dict) -> Type[nn.Module]: + """Loads a backbone network.""" + network_module = importlib.import_module("text_recognizer.networks") + backbone_ = getattr(network_module, backbone) + + if "pretrained" in backbone_args: + logger.info("Loading pretrained backbone.") + checkpoint_file = Path(__file__).resolve().parents[2] / backbone_args.pop( + "pretrained" + ) + + # Loading state directory. + state_dict = torch.load(checkpoint_file) + network_args = state_dict["network_args"] + weights = state_dict["model_state"] + + freeze = False + if "freeze" in backbone_args and backbone_args["freeze"] is True: + backbone_args.pop("freeze") + freeze = True + network_args = backbone_args + + # Initializes the network with trained weights. + backbone = backbone_(**network_args) + backbone.load_state_dict(weights) + if freeze: + for params in backbone.parameters(): + params.requires_grad = False + else: + backbone_ = getattr(network_module, backbone) + backbone = backbone_(**backbone_args) + + if "remove_layers" in backbone_args and backbone_args["remove_layers"] is not None: + backbone = nn.Sequential( + *list(backbone.children())[:][: -backbone_args["remove_layers"]] + ) + + return backbone diff --git a/text_recognizer/networks/vit.py b/text_recognizer/networks/vit.py new file mode 100644 index 0000000..efb3701 --- /dev/null +++ b/text_recognizer/networks/vit.py @@ -0,0 +1,150 @@ +"""A Vision Transformer. 
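+
+The image is split into fixed-size patches, which are linearly embedded,
+combined with a positional embedding, and decoded into a character sequence
+by a transformer.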
+
+Inspired by:
+https://openreview.net/pdf?id=YicbFdNTTy
+
+"""
+from typing import Optional, Tuple
+
+from einops import rearrange, repeat
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.transformer import Transformer
+
+
+class ViT(nn.Module):
+    """Transformer for image to sequence prediction."""
+
+    def __init__(
+        self,
+        num_encoder_layers: int,
+        num_decoder_layers: int,
+        hidden_dim: int,
+        vocab_size: int,
+        num_heads: int,
+        expansion_dim: int,
+        patch_dim: Tuple[int, int],
+        image_size: Tuple[int, int],
+        dropout_rate: float,
+        trg_pad_index: int,
+        max_len: int,
+        activation: str = "gelu",
+    ) -> None:
+        super().__init__()
+
+        self.trg_pad_index = trg_pad_index
+        self.patch_dim = patch_dim
+        self.num_patches = image_size[-1] // self.patch_dim[1]
+
+        # Encoder
+        self.patch_to_embedding = nn.Linear(
+            self.patch_dim[0] * self.patch_dim[1], hidden_dim
+        )
+        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
+        self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
+        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
+        self.dropout = nn.Dropout(dropout_rate)
+        self._init()
+
+        self.transformer = Transformer(
+            num_encoder_layers,
+            num_decoder_layers,
+            hidden_dim,
+            num_heads,
+            expansion_dim,
+            dropout_rate,
+            activation,
+        )
+
+        self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
+
+    def _init(self) -> None:
+        nn.init.normal_(self.character_embedding.weight, std=0.02)
+        # nn.init.normal_(self.pos_embedding.weight, std=0.02)
+
+    def _create_trg_mask(self, trg: Tensor) -> Tensor:
+        # TODO: Move this outside the transformer.
+        trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
+        trg_len = trg.shape[1]
+        trg_sub_mask = torch.tril(
+            torch.ones((trg_len, trg_len), device=trg.device)
+        ).bool()
+        trg_mask = trg_pad_mask & trg_sub_mask
+        return trg_mask
+
+    def encoder(self, src: Tensor) -> Tensor:
+        """Forward pass with the encoder of the transformer."""
+        return self.transformer.encoder(src)
+
+    def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
+        """Forward pass with the decoder of the transformer + classification head."""
+        return self.head(
+            self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+        )
+
+    def extract_image_features(self, src: Tensor) -> Tensor:
+        """Extracts image features by embedding image patches.
+
+        The image is split into flattened patches, each patch is projected to
+        the hidden dimension, a class token is prepended, and a learned
+        positional embedding is added.
+
+        Args:
+            src (Tensor): Input tensor.
+
+        Returns:
+            Tensor: The src input sequence to the transformer.
+
+        """
+        # If batch dimension is missing, it needs to be added.
+        if len(src.shape) < 4:
+            src = src[(None,) * (4 - len(src.shape))]
+
+        patches = rearrange(
+            src,
+            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
+            p1=self.patch_dim[0],
+            p2=self.patch_dim[1],
+        )
+
+        # From patches to encoded sequence.
+        x = self.patch_to_embedding(patches)
+        b, n, _ = x.shape
+        cls_tokens = repeat(self.cls_token, "() n d -> b n d", b=b)
+        x = torch.cat((cls_tokens, x), dim=1)
+        x += self.pos_embedding[:, : (n + 1)]
+        x = self.dropout(x)
+
+        return x
+
+    def target_embedding(self, trg: Tensor) -> Tensor:
+        """Encodes target tensor with embedding and position.
+
+        Args:
+            trg (Tensor): Target tensor.
+
+        Returns:
+            Tensor: Encoded target tensor.
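+
+        Example (illustrative shapes): a trg of shape (B, N) is embedded to
+        (B, N, hidden_dim), with the positional embedding added elementwise.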
+ + """ + _, n = trg.shape + trg = self.character_embedding(trg.long()) + trg += self.pos_embedding[:, :n] + return trg + + def decode_image_features(self, h: Tensor, trg: Optional[Tensor] = None) -> Tensor: + """Takes images features from the backbone and decodes them with the transformer.""" + trg_mask = self._create_trg_mask(trg) + trg = self.target_embedding(trg) + out = self.transformer(h, trg, trg_mask=trg_mask) + + logits = self.head(out) + return logits + + def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor: + """Forward pass with CNN transfomer.""" + h = self.extract_image_features(x) + logits = self.decode_image_features(h, trg) + return logits diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py new file mode 100644 index 0000000..c673d96 --- /dev/null +++ b/text_recognizer/networks/vq_transformer.py @@ -0,0 +1,150 @@ +"""A VQ-Transformer for image to text recognition.""" +from typing import Dict, Optional, Tuple + +from einops import rearrange, repeat +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.transformer import PositionalEncoding, Transformer +from text_recognizer.networks.util import activation_function +from text_recognizer.networks.util import configure_backbone +from text_recognizer.networks.vqvae.encoder import _ResidualBlock + + +class VQTransformer(nn.Module): + """VQ+Transfomer for image to character sequence prediction.""" + + def __init__( + self, + num_encoder_layers: int, + num_decoder_layers: int, + hidden_dim: int, + vocab_size: int, + num_heads: int, + adaptive_pool_dim: Tuple, + expansion_dim: int, + dropout_rate: float, + trg_pad_index: int, + max_len: int, + backbone: str, + backbone_args: Optional[Dict] = None, + activation: str = "gelu", + ) -> None: + super().__init__() + + # Configure vector quantized backbone. + self.backbone = configure_backbone(backbone, backbone_args) + self.conv = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2), + nn.ReLU(inplace=True), + ) + + # Configure embeddings for Transformer network. + self.trg_pad_index = trg_pad_index + self.vocab_size = vocab_size + self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim) + self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim)) + self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate) + nn.init.normal_(self.character_embedding.weight, std=0.02) + + self.adaptive_pool = ( + nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None + ) + + self.transformer = Transformer( + num_encoder_layers, + num_decoder_layers, + hidden_dim, + num_heads, + expansion_dim, + dropout_rate, + activation, + ) + + self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),) + + def _create_trg_mask(self, trg: Tensor) -> Tensor: + # Move this outside the transformer. 
+        trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
+        trg_len = trg.shape[1]
+        trg_sub_mask = torch.tril(
+            torch.ones((trg_len, trg_len), device=trg.device)
+        ).bool()
+        trg_mask = trg_pad_mask & trg_sub_mask
+        return trg_mask
+
+    def encoder(self, src: Tensor) -> Tensor:
+        """Forward pass with the encoder of the transformer."""
+        return self.transformer.encoder(src)
+
+    def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
+        """Forward pass with the decoder of the transformer + classification head."""
+        return self.head(
+            self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+        )
+
+    def extract_image_features(self, src: Tensor) -> Tuple[Tensor, Tensor]:
+        """Extracts image features with the vector quantized backbone.
+
+        The backbone encodes the image into a discrete latent representation,
+        which is compressed further and flattened into a sequence for the
+        transformer. Swapping the channel and width dimensions and collapsing
+        the height dimension proved to work well in practice.
+
+        Args:
+            src (Tensor): Input tensor.
+
+        Returns:
+            Tuple[Tensor, Tensor]: The src input to the transformer and the vq loss.
+
+        """
+        # If batch dimension is missing, it needs to be added.
+        if len(src.shape) < 4:
+            src = src[(None,) * (4 - len(src.shape))]
+        src, vq_loss = self.backbone.encode(src)
+        # src = self.backbone.decoder.res_block(src)
+        src = self.conv(src)
+
+        if self.adaptive_pool is not None:
+            src = rearrange(src, "b c h w -> b w c h")
+            src = self.adaptive_pool(src)
+            src = src.squeeze(3)
+        else:
+            src = rearrange(src, "b c h w -> b (w h) c")
+
+        b, t, _ = src.shape
+
+        src += self.src_position_embedding[:, :t]
+
+        return src, vq_loss
+
+    def target_embedding(self, trg: Tensor) -> Tensor:
+        """Encodes target tensor with embedding and position.
+
+        Args:
+            trg (Tensor): Target tensor.
+
+        Returns:
+            Tensor: Encoded target tensor.
+ + """ + trg = self.character_embedding(trg.long()) + trg = self.trg_position_encoding(trg) + return trg + + def decode_image_features( + self, image_features: Tensor, trg: Optional[Tensor] = None + ) -> Tensor: + """Takes images features from the backbone and decodes them with the transformer.""" + trg_mask = self._create_trg_mask(trg) + trg = self.target_embedding(trg) + out = self.transformer(image_features, trg, trg_mask=trg_mask) + + logits = self.head(out) + return logits + + def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor: + """Forward pass with CNN transfomer.""" + image_features, vq_loss = self.extract_image_features(x) + logits = self.decode_image_features(image_features, trg) + return logits, vq_loss diff --git a/text_recognizer/networks/vqvae/__init__.py b/text_recognizer/networks/vqvae/__init__.py new file mode 100644 index 0000000..763953c --- /dev/null +++ b/text_recognizer/networks/vqvae/__init__.py @@ -0,0 +1,5 @@ +"""VQ-VAE module.""" +from .decoder import Decoder +from .encoder import Encoder +from .vector_quantizer import VectorQuantizer +from .vqvae import VQVAE diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py new file mode 100644 index 0000000..8847aba --- /dev/null +++ b/text_recognizer/networks/vqvae/decoder.py @@ -0,0 +1,133 @@ +"""CNN decoder for the VQ-VAE.""" + +from typing import List, Optional, Tuple, Type + +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.util import activation_function +from text_recognizer.networks.vqvae.encoder import _ResidualBlock + + +class Decoder(nn.Module): + """A CNN encoder network.""" + + def __init__( + self, + channels: List[int], + kernel_sizes: List[int], + strides: List[int], + num_residual_layers: int, + embedding_dim: int, + upsampling: Optional[List[List[int]]] = None, + activation: str = "leaky_relu", + dropout_rate: float = 0.0, + ) -> None: + super().__init__() + + if dropout_rate: + if activation == "selu": + dropout = nn.AlphaDropout(p=dropout_rate) + else: + dropout = nn.Dropout(p=dropout_rate) + else: + dropout = None + + self.upsampling = upsampling + + self.res_block = nn.ModuleList([]) + self.upsampling_block = nn.ModuleList([]) + + self.embedding_dim = embedding_dim + activation = activation_function(activation) + + # Configure encoder. 
+ self.decoder = self._build_decoder( + channels, kernel_sizes, strides, num_residual_layers, activation, dropout, + ) + + def _build_decompression_block( + self, + in_channels: int, + channels: int, + kernel_sizes: List[int], + strides: List[int], + activation: Type[nn.Module], + dropout: Optional[Type[nn.Module]], + ) -> nn.ModuleList: + modules = nn.ModuleList([]) + configuration = zip(channels, kernel_sizes, strides) + for i, (out_channels, kernel_size, stride) in enumerate(configuration): + modules.append( + nn.Sequential( + nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=1, + ), + activation, + ) + ) + + if i < len(self.upsampling): + modules.append(nn.Upsample(size=self.upsampling[i]),) + + if dropout is not None: + modules.append(dropout) + + in_channels = out_channels + + modules.extend( + nn.Sequential( + nn.ConvTranspose2d( + in_channels, 1, kernel_size=kernel_size, stride=stride, padding=1 + ), + nn.Tanh(), + ) + ) + + return modules + + def _build_decoder( + self, + channels: int, + kernel_sizes: List[int], + strides: List[int], + num_residual_layers: int, + activation: Type[nn.Module], + dropout: Optional[Type[nn.Module]], + ) -> nn.Sequential: + + self.res_block.append( + nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,) + ) + + # Bottleneck module. + self.res_block.extend( + nn.ModuleList( + [ + _ResidualBlock(channels[0], channels[0], dropout) + for i in range(num_residual_layers) + ] + ) + ) + + # Decompression module + self.upsampling_block.extend( + self._build_decompression_block( + channels[0], channels[1:], kernel_sizes, strides, activation, dropout + ) + ) + + self.res_block = nn.Sequential(*self.res_block) + self.upsampling_block = nn.Sequential(*self.upsampling_block) + + return nn.Sequential(self.res_block, self.upsampling_block) + + def forward(self, z_q: Tensor) -> Tensor: + """Reconstruct input from given codes.""" + x_reconstruction = self.decoder(z_q) + return x_reconstruction diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py new file mode 100644 index 0000000..d3adac5 --- /dev/null +++ b/text_recognizer/networks/vqvae/encoder.py @@ -0,0 +1,147 @@ +"""CNN encoder for the VQ-VAE.""" +from typing import List, Optional, Tuple, Type + +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.util import activation_function +from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer + + +class _ResidualBlock(nn.Module): + def __init__( + self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]], + ) -> None: + super().__init__() + self.block = [ + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False), + ] + + if dropout is not None: + self.block.append(dropout) + + self.block = nn.Sequential(*self.block) + + def forward(self, x: Tensor) -> Tensor: + """Apply the residual forward pass.""" + return x + self.block(x) + + +class Encoder(nn.Module): + """A CNN encoder network.""" + + def __init__( + self, + in_channels: int, + channels: List[int], + kernel_sizes: List[int], + strides: List[int], + num_residual_layers: int, + embedding_dim: int, + num_embeddings: int, + beta: float = 0.25, + activation: str = "leaky_relu", + dropout_rate: float = 0.0, + ) -> None: + super().__init__() + + if dropout_rate: + if activation == "selu": + dropout = 
nn.AlphaDropout(p=dropout_rate) + else: + dropout = nn.Dropout(p=dropout_rate) + else: + dropout = None + + self.embedding_dim = embedding_dim + self.num_embeddings = num_embeddings + self.beta = beta + activation = activation_function(activation) + + # Configure encoder. + self.encoder = self._build_encoder( + in_channels, + channels, + kernel_sizes, + strides, + num_residual_layers, + activation, + dropout, + ) + + # Configure Vector Quantizer. + self.vector_quantizer = VectorQuantizer( + self.num_embeddings, self.embedding_dim, self.beta + ) + + def _build_compression_block( + self, + in_channels: int, + channels: int, + kernel_sizes: List[int], + strides: List[int], + activation: Type[nn.Module], + dropout: Optional[Type[nn.Module]], + ) -> nn.ModuleList: + modules = nn.ModuleList([]) + configuration = zip(channels, kernel_sizes, strides) + for out_channels, kernel_size, stride in configuration: + modules.append( + nn.Sequential( + nn.Conv2d( + in_channels, out_channels, kernel_size, stride=stride, padding=1 + ), + activation, + ) + ) + + if dropout is not None: + modules.append(dropout) + + in_channels = out_channels + + return modules + + def _build_encoder( + self, + in_channels: int, + channels: int, + kernel_sizes: List[int], + strides: List[int], + num_residual_layers: int, + activation: Type[nn.Module], + dropout: Optional[Type[nn.Module]], + ) -> nn.Sequential: + encoder = nn.ModuleList([]) + + # compression module + encoder.extend( + self._build_compression_block( + in_channels, channels, kernel_sizes, strides, activation, dropout + ) + ) + + # Bottleneck module. + encoder.extend( + nn.ModuleList( + [ + _ResidualBlock(channels[-1], channels[-1], dropout) + for i in range(num_residual_layers) + ] + ) + ) + + encoder.append( + nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,) + ) + + return nn.Sequential(*encoder) + + def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: + """Encodes input into a discrete representation.""" + z_e = self.encoder(x) + z_q, vq_loss = self.vector_quantizer(z_e) + return z_q, vq_loss diff --git a/text_recognizer/networks/vqvae/vector_quantizer.py b/text_recognizer/networks/vqvae/vector_quantizer.py new file mode 100644 index 0000000..f92c7ee --- /dev/null +++ b/text_recognizer/networks/vqvae/vector_quantizer.py @@ -0,0 +1,119 @@ +"""Implementation of a Vector Quantized Variational AutoEncoder. + +Reference: +https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py + +""" + +from einops import rearrange +import torch +from torch import nn +from torch import Tensor +from torch.nn import functional as F + + +class VectorQuantizer(nn.Module): + """The codebook that contains quantized vectors.""" + + def __init__( + self, num_embeddings: int, embedding_dim: int, beta: float = 0.25 + ) -> None: + super().__init__() + self.K = num_embeddings + self.D = embedding_dim + self.beta = beta + + self.embedding = nn.Embedding(self.K, self.D) + + # Initialize the codebook. + nn.init.uniform_(self.embedding.weight, -1 / self.K, 1 / self.K) + + def discretization_bottleneck(self, latent: Tensor) -> Tensor: + """Computes the code nearest to the latent representation. + + First we compute the posterior categorical distribution, and then map + the latent representation to the nearest element of the embedding. + + Args: + latent (Tensor): The latent representation. + + Shape: + - latent :math:`(B x H x W, D)` + + Returns: + Tensor: The quantized embedding vector. + + """ + # Store latent shape. 
+        b, h, w, d = latent.shape
+
+        # Flatten the latent representation to 2D.
+        latent = rearrange(latent, "b h w d -> (b h w) d")
+
+        # Compute the L2 distance between the latents and the embeddings,
+        # using ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z.e to avoid materializing
+        # a full (BHW, K, D) difference tensor.
+        l2_distance = (
+            torch.sum(latent ** 2, dim=1, keepdim=True)
+            + torch.sum(self.embedding.weight ** 2, dim=1)
+            - 2 * latent @ self.embedding.weight.t()
+        )  # [BHW x K]
+
+        # Find the embedding k nearest to each latent.
+        encoding_indices = torch.argmin(l2_distance, dim=1).unsqueeze(1)  # [BHW, 1]
+
+        # Convert to one-hot encodings, aka discrete bottleneck.
+        one_hot_encoding = torch.zeros(
+            encoding_indices.shape[0], self.K, device=latent.device
+        )
+        one_hot_encoding.scatter_(1, encoding_indices, 1)  # [BHW x K]
+
+        # Embedding quantization.
+        quantized_latent = one_hot_encoding @ self.embedding.weight  # [BHW, D]
+        quantized_latent = rearrange(
+            quantized_latent, "(b h w) d -> b h w d", b=b, h=h, w=w
+        )
+
+        return quantized_latent
+
+    def vq_loss(self, latent: Tensor, quantized_latent: Tensor) -> Tensor:
+        """Vector Quantization loss.
+
+        The vector quantization algorithm allows us to create a codebook. The VQ
+        algorithm works by moving the embedding vectors towards the encoder outputs.
+
+        The embedding loss moves the embedding vectors towards the encoder outputs. The
+        .detach() works as the stop gradient (sg) described in the paper.
+
+        Because the volume of the embedding space is dimensionless, it can arbitrarily
+        grow if the embeddings are not trained as fast as the encoder parameters. To
+        mitigate this, a commitment loss is added as the second term, which makes sure
+        that the encoder commits to an embedding and that its output does not grow.
+
+        Args:
+            latent (Tensor): The encoder output.
+            quantized_latent (Tensor): The quantized latent.
+
+        Returns:
+            Tensor: The combined VQ loss.
+
+        """
+        embedding_loss = F.mse_loss(quantized_latent, latent.detach())
+        commitment_loss = F.mse_loss(quantized_latent.detach(), latent)
+        return embedding_loss + self.beta * commitment_loss
+
+    def forward(self, latent: Tensor) -> Tensor:
+        """Forward pass that returns the quantized vector and the vq loss."""
+        # Rearrange latent representation s.t. the hidden dim is at the end.
+        latent = rearrange(latent, "b d h w -> b h w d")
+
+        # Maps latent to the nearest code in the codebook.
+        quantized_latent = self.discretization_bottleneck(latent)
+
+        loss = self.vq_loss(latent, quantized_latent)
+
+        # Straight-through estimator: adding the quantization residual without
+        # gradient lets decoder gradients flow straight to the encoder output.
+        quantized_latent = latent + (quantized_latent - latent).detach()
+
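+        # Illustrative usage (an added sketch, not part of the original file):
+        #
+        #     vq = VectorQuantizer(num_embeddings=512, embedding_dim=64)
+        #     z_e = torch.randn(2, 64, 8, 32)  # (B, D, H, W) encoder output
+        #     z_q, vq_loss = vq(z_e)           # z_q: (2, 64, 8, 32)
+
+        # Rearrange the quantized shape back to the original shape.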
+ quantized_latent = rearrange(quantized_latent, "b h w d -> b d h w") + + return quantized_latent, loss diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py new file mode 100644 index 0000000..50448b4 --- /dev/null +++ b/text_recognizer/networks/vqvae/vqvae.py @@ -0,0 +1,74 @@ +"""The VQ-VAE.""" + +from typing import List, Optional, Tuple, Type + +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.vqvae import Decoder, Encoder + + +class VQVAE(nn.Module): + """Vector Quantized Variational AutoEncoder.""" + + def __init__( + self, + in_channels: int, + channels: List[int], + kernel_sizes: List[int], + strides: List[int], + num_residual_layers: int, + embedding_dim: int, + num_embeddings: int, + upsampling: Optional[List[List[int]]] = None, + beta: float = 0.25, + activation: str = "leaky_relu", + dropout_rate: float = 0.0, + ) -> None: + super().__init__() + + # configure encoder. + self.encoder = Encoder( + in_channels, + channels, + kernel_sizes, + strides, + num_residual_layers, + embedding_dim, + num_embeddings, + beta, + activation, + dropout_rate, + ) + + # Configure decoder. + channels.reverse() + kernel_sizes.reverse() + strides.reverse() + self.decoder = Decoder( + channels, + kernel_sizes, + strides, + num_residual_layers, + embedding_dim, + upsampling, + activation, + dropout_rate, + ) + + def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]: + """Encodes input to a latent code.""" + return self.encoder(x) + + def decode(self, z_q: Tensor) -> Tensor: + """Reconstructs input from latent codes.""" + return self.decoder(z_q) + + def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: + """Compresses and decompresses input.""" + if len(x.shape) < 4: + x = x[(None,) * (4 - len(x.shape))] + z_q, vq_loss = self.encode(x) + x_reconstruction = self.decode(z_q) + return x_reconstruction, vq_loss diff --git a/text_recognizer/networks/wide_resnet.py b/text_recognizer/networks/wide_resnet.py new file mode 100644 index 0000000..b767778 --- /dev/null +++ b/text_recognizer/networks/wide_resnet.py @@ -0,0 +1,221 @@ +"""Wide Residual CNN.""" +from functools import partial +from typing import Callable, Dict, List, Optional, Type, Union + +from einops.layers.torch import Reduce +import numpy as np +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.util import activation_function + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """Helper function for a 3x3 2d convolution.""" + return nn.Conv2d( + in_channels=in_planes, + out_channels=out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False, + ) + + +def conv_init(module: Type[nn.Module]) -> None: + """Initializes the weights for convolution and batchnorms.""" + classname = module.__class__.__name__ + if classname.find("Conv") != -1: + nn.init.xavier_uniform_(module.weight, gain=np.sqrt(2)) + nn.init.constant_(module.bias, 0) + elif classname.find("BatchNorm") != -1: + nn.init.constant_(module.weight, 1) + nn.init.constant_(module.bias, 0) + + +class WideBlock(nn.Module): + """Block used in WideResNet.""" + + def __init__( + self, + in_planes: int, + out_planes: int, + dropout_rate: float, + stride: int = 1, + activation: str = "relu", + ) -> None: + super().__init__() + self.in_planes = in_planes + self.out_planes = out_planes + self.dropout_rate = dropout_rate + self.stride = stride + self.activation = activation_function(activation) + + # Build blocks. 
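+        # Note (added): pre-activation ordering (BatchNorm -> activation ->
+        # conv), with dropout between the two 3x3 convolutions, as in the
+        # wide residual network paper.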
+ self.blocks = nn.Sequential( + nn.BatchNorm2d(self.in_planes), + self.activation, + conv3x3(in_planes=self.in_planes, out_planes=self.out_planes), + nn.Dropout(p=self.dropout_rate), + nn.BatchNorm2d(self.out_planes), + self.activation, + conv3x3( + in_planes=self.out_planes, + out_planes=self.out_planes, + stride=self.stride, + ), + ) + + self.shortcut = ( + nn.Sequential( + nn.Conv2d( + in_channels=self.in_planes, + out_channels=self.out_planes, + kernel_size=1, + stride=self.stride, + bias=False, + ), + ) + if self._apply_shortcut + else None + ) + + @property + def _apply_shortcut(self) -> bool: + """If shortcut should be applied or not.""" + return self.stride != 1 or self.in_planes != self.out_planes + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + residual = x + if self._apply_shortcut: + residual = self.shortcut(x) + x = self.blocks(x) + x += residual + return x + + +class WideResidualNetwork(nn.Module): + """WideResNet for character predictions. + + Can be used for classification or encoding of images to a latent vector. + + """ + + def __init__( + self, + in_channels: int = 1, + in_planes: int = 16, + num_classes: int = 80, + depth: int = 16, + width_factor: int = 10, + dropout_rate: float = 0.0, + num_layers: int = 3, + block: Type[nn.Module] = WideBlock, + num_stages: Optional[List[int]] = None, + activation: str = "relu", + use_decoder: bool = True, + ) -> None: + """The initialization of the WideResNet. + + Args: + in_channels (int): Number of input channels. Defaults to 1. + in_planes (int): Number of channels to use in the first output kernel. Defaults to 16. + num_classes (int): Number of classes. Defaults to 80. + depth (int): Set the number of blocks to use. Defaults to 16. + width_factor (int): Factor for scaling the number of channels in the network. Defaults to 10. + dropout_rate (float): The dropout rate. Defaults to 0.0. + num_layers (int): Number of layers of blocks. Defaults to 3. + block (Type[nn.Module]): The default block is WideBlock. Defaults to WideBlock. + num_stages (List[int]): If given, will use these channel values. Defaults to None. + activation (str): Name of the activation to use. Defaults to "relu". + use_decoder (bool): If True, the network output character predictions, if False, the network outputs a + latent vector. Defaults to True. + + Raises: + RuntimeError: If the depth is not of the size `6n+4`. 
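+
+        Example (an illustrative note, not in the original docstring):
+            WideResidualNetwork(depth=16, width_factor=10) maps a
+            (1, 1, 28, 28) input to (1, 80) class logits.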
+ + """ + + super().__init__() + if (depth - 4) % 6 != 0: + raise RuntimeError("Wide-resnet depth should be 6n+4") + self.in_channels = in_channels + self.in_planes = in_planes + self.num_classes = num_classes + self.num_blocks = (depth - 4) // 6 + self.width_factor = width_factor + self.num_layers = num_layers + self.block = block + self.dropout_rate = dropout_rate + self.activation = activation_function(activation) + + if num_stages is None: + self.num_stages = [self.in_planes] + [ + self.in_planes * 2 ** n * self.width_factor + for n in range(self.num_layers) + ] + else: + self.num_stages = [self.in_planes] + num_stages + + self.num_stages = list(zip(self.num_stages, self.num_stages[1:])) + self.strides = [1] + [2] * (self.num_layers - 1) + + self.encoder = nn.Sequential( + conv3x3(in_planes=self.in_channels, out_planes=self.in_planes), + *[ + self._configure_wide_layer( + in_planes=in_planes, + out_planes=out_planes, + stride=stride, + activation=activation, + ) + for (in_planes, out_planes), stride in zip( + self.num_stages, self.strides + ) + ], + ) + + self.decoder = ( + nn.Sequential( + nn.BatchNorm2d(self.num_stages[-1][-1], momentum=0.8), + self.activation, + Reduce("b c h w -> b c", "mean"), + nn.Linear( + in_features=self.num_stages[-1][-1], out_features=self.num_classes + ), + ) + if use_decoder + else None + ) + + # self.apply(conv_init) + + def _configure_wide_layer( + self, in_planes: int, out_planes: int, stride: int, activation: str + ) -> List: + strides = [stride] + [1] * (self.num_blocks - 1) + planes = [out_planes] * len(strides) + planes = [(in_planes, out_planes)] + list(zip(planes, planes[1:])) + return nn.Sequential( + *[ + self.block( + in_planes=in_planes, + out_planes=out_planes, + dropout_rate=self.dropout_rate, + stride=stride, + activation=activation, + ) + for (in_planes, out_planes), stride in zip(planes, strides) + ] + ) + + def forward(self, x: Tensor) -> Tensor: + """Feedforward pass.""" + if len(x.shape) < 4: + x = x[(None,) * int(4 - len(x.shape))] + x = self.encoder(x) + if self.decoder is not None: + x = self.decoder(x) + return x diff --git a/text_recognizer/paragraph_text_recognizer.py b/text_recognizer/paragraph_text_recognizer.py new file mode 100644 index 0000000..aa39662 --- /dev/null +++ b/text_recognizer/paragraph_text_recognizer.py @@ -0,0 +1,153 @@ +"""Full model. + +Takes an image and returns the text in the image, by first segmenting the image with a LineDetector, then extracting the +each crop of the image corresponding to line regions, and feeding them to a LinePredictor model that outputs the text +in each region. 
+""" +from typing import Dict, List, Tuple, Union + +import cv2 +import numpy as np +import torch + +from text_recognizer.models import SegmentationModel, TransformerModel +from text_recognizer.util import read_image + + +class ParagraphTextRecognizor: + """Given an image of a single handwritten character, recognizes it.""" + + def __init__(self, line_predictor_args: Dict, line_detector_args: Dict) -> None: + self._line_predictor = TransformerModel(**line_predictor_args) + self._line_detector = SegmentationModel(**line_detector_args) + self._line_detector.eval() + self._line_predictor.eval() + + def predict(self, image_or_filename: Union[str, np.ndarray]) -> Tuple: + """Takes an image and returns all text within it.""" + image = ( + read_image(image_or_filename) + if isinstance(image_or_filename, str) + else image_or_filename + ) + + line_region_crops = self._get_line_region_crops(image) + processed_line_region_crops = [ + self._process_image_for_line_predictor(image=crop) + for crop in line_region_crops + ] + line_region_strings = [ + self.line_predictor_model.predict_on_image(crop)[0] + for crop in processed_line_region_crops + ] + + return " ".join(line_region_strings), line_region_crops + + def _get_line_region_crops( + self, image: np.ndarray, min_crop_len_factor: float = 0.02 + ) -> List[np.ndarray]: + """Returns all the crops of text lines in a square image.""" + processed_image, scale_down_factor = self._process_image_for_line_detector( + image + ) + line_segmentation = self._line_detector.predict_on_image(processed_image) + bounding_boxes = _find_line_bounding_boxes(line_segmentation) + + bounding_boxes = (bounding_boxes * scale_down_factor).astype(int) + + min_crop_len = int(min_crop_len_factor * min(image.shape[0], image.shape[1])) + line_region_crops = [ + image[y : y + h, x : x + w] + for x, y, w, h in bounding_boxes + if w >= min_crop_len and h >= min_crop_len + ] + return line_region_crops + + def _process_image_for_line_detector( + self, image: np.ndarray + ) -> Tuple[np.ndarray, float]: + """Convert uint8 image to float image with black background with shape self._line_detector.image_shape.""" + resized_image, scale_down_factor = _resize_image_for_line_detector( + image=image, max_shape=self._line_detector.image_shape + ) + resized_image = (1.0 - resized_image / 255).astype("float32") + return resized_image, scale_down_factor + + def _process_image_for_line_predictor(self, image: np.ndarray) -> np.ndarray: + """Preprocessing of image before feeding it to the LinePrediction model. + + Convert uint8 image to float image with black background with shape + self._line_predictor.image_shape while maintaining the image aspect ratio. + + Args: + image (np.ndarray): Crop of text line. + + Returns: + np.ndarray: Processed crop for feeding line predictor. 
+ """ + expected_shape = self._line_detector.image_shape + scale_factor = (np.array(expected_shape) / np.array(image.shape)).min() + scaled_image = cv2.resize( + image, + dsize=None, + fx=scale_factor, + fy=scale_factor, + interpolation=cv2.INTER_AREA, + ) + + pad_with = ( + (0, expected_shape[0] - scaled_image.shape[0]), + (0, expected_shape[1] - scaled_image.shape[1]), + ) + + padded_image = np.pad( + scaled_image, pad_with=pad_with, mode="constant", constant_values=255 + ) + return 1 - padded_image / 255 + + +def _find_line_bounding_boxes(line_segmentation: np.ndarray) -> np.ndarray: + """Given a line segmentation, find bounding boxes for connected-component regions corresponding to non-0 labels.""" + + def _find_line_bounding_boxes_in_channel( + line_segmentation_channel: np.ndarray, + ) -> np.ndarray: + line_segmentation_image = cv2.dilate( + line_segmentation_channel, kernel=np.ones((3, 3)), iterations=1 + ) + line_activation_image = (line_segmentation_image * 255).astype("uint8") + line_activation_image = cv2.threshold( + line_activation_image, 0.5, 1, cv2.THRESH_BINARY | cv2.THRESH_OTSU + )[1] + + bounding_cnts, _ = cv2.findContours( + line_segmentation_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE + ) + return np.array([cv2.boundingRect(cnt) for cnt in bounding_cnts]) + + bounding_boxes = np.concatenate( + [ + _find_line_bounding_boxes_in_channel(line_segmentation[:, :, i]) + for i in [1, 2] + ], + axis=0, + ) + + return bounding_boxes[np.argsort(bounding_boxes[:, 1])] + + +def _resize_image_for_line_detector( + image: np.ndarray, max_shape: Tuple[int, int] +) -> Tuple[np.ndarray, float]: + """Resize the image to less than the max_shape while maintaining the aspect ratio.""" + scale_down_factor = max(np.ndarray(image.shape) / np.ndarray(max_shape)) + if scale_down_factor == 1: + return image.copy(), scale_down_factor + resize_image = cv2.resize( + image, + dsize=None, + fx=1 / scale_down_factor, + fy=1 / scale_down_factor, + interpolation=cv2.INTER_AREA, + ) + return resize_image, scale_down_factor diff --git a/text_recognizer/tests/__init__.py b/text_recognizer/tests/__init__.py new file mode 100644 index 0000000..18ff212 --- /dev/null +++ b/text_recognizer/tests/__init__.py @@ -0,0 +1 @@ +"""Test modules for the text text recognizer.""" diff --git a/text_recognizer/tests/support/__init__.py b/text_recognizer/tests/support/__init__.py new file mode 100644 index 0000000..a265ede --- /dev/null +++ b/text_recognizer/tests/support/__init__.py @@ -0,0 +1,2 @@ +"""Support file modules.""" +from .create_emnist_support_files import create_emnist_support_files diff --git a/text_recognizer/tests/support/create_emnist_lines_support_files.py b/text_recognizer/tests/support/create_emnist_lines_support_files.py new file mode 100644 index 0000000..9abe143 --- /dev/null +++ b/text_recognizer/tests/support/create_emnist_lines_support_files.py @@ -0,0 +1,51 @@ +"""Module for creating EMNIST Lines test support files.""" +# flake8: noqa: S106 + +from pathlib import Path +import shutil + +import numpy as np + +from text_recognizer.datasets import EmnistLinesDataset +import text_recognizer.util as util + + +SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "emnist_lines" + + +def create_emnist_lines_support_files() -> None: + """Create EMNIST Lines test images.""" + shutil.rmtree(SUPPORT_DIRNAME, ignore_errors=True) + SUPPORT_DIRNAME.mkdir() + + # TODO: maybe have to add args to dataset. 
+ dataset = EmnistLinesDataset( + init_token="<sos>", + pad_token="_", + eos_token="<eos>", + transform=[{"type": "ToTensor", "args": {}}], + target_transform=[ + { + "type": "AddTokens", + "args": {"init_token": "<sos>", "pad_token": "_", "eos_token": "<eos>"}, + } + ], + ) # nosec: S106 + dataset.load_or_generate_data() + + for index in [5, 7, 9]: + image, target = dataset[index] + if len(image.shape) == 3: + image = image.squeeze(0) + print(image.sum(), image.dtype) + + label = "".join(dataset.mapper(label) for label in target[1:]).strip( + dataset.mapper.pad_token + ) + print(label) + image = image.numpy() + util.write_image(image, str(SUPPORT_DIRNAME / f"{label}.png")) + + +if __name__ == "__main__": + create_emnist_lines_support_files() diff --git a/text_recognizer/tests/support/create_emnist_support_files.py b/text_recognizer/tests/support/create_emnist_support_files.py new file mode 100644 index 0000000..f9ff030 --- /dev/null +++ b/text_recognizer/tests/support/create_emnist_support_files.py @@ -0,0 +1,30 @@ +"""Module for creating EMNIST test support files.""" +from pathlib import Path +import shutil + +from text_recognizer.datasets import EmnistDataset +from text_recognizer.util import write_image + +SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "emnist" + + +def create_emnist_support_files() -> None: + """Create support images for test of CharacterPredictor class.""" + shutil.rmtree(SUPPORT_DIRNAME, ignore_errors=True) + SUPPORT_DIRNAME.mkdir() + + dataset = EmnistDataset(train=False) + dataset.load_or_generate_data() + + for index in [5, 7, 9]: + image, label = dataset[index] + if len(image.shape) == 3: + image = image.squeeze(0) + image = image.numpy() + label = dataset.mapper(int(label)) + print(index, label) + write_image(image, str(SUPPORT_DIRNAME / f"{label}.png")) + + +if __name__ == "__main__": + create_emnist_support_files() diff --git a/text_recognizer/tests/support/create_iam_lines_support_files.py b/text_recognizer/tests/support/create_iam_lines_support_files.py new file mode 100644 index 0000000..50f9e3d --- /dev/null +++ b/text_recognizer/tests/support/create_iam_lines_support_files.py @@ -0,0 +1,50 @@ +"""Module for creating IAM Lines test support files.""" +# flake8: noqa +from pathlib import Path +import shutil + +import numpy as np + +from text_recognizer.datasets import IamLinesDataset +import text_recognizer.util as util + + +SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "iam_lines" + + +def create_emnist_lines_support_files() -> None: + """Create IAM Lines test images.""" + shutil.rmtree(SUPPORT_DIRNAME, ignore_errors=True) + SUPPORT_DIRNAME.mkdir() + + # TODO: maybe have to add args to dataset. 
+ dataset = IamLinesDataset( + init_token="<sos>", + pad_token="_", + eos_token="<eos>", + transform=[{"type": "ToTensor", "args": {}}], + target_transform=[ + { + "type": "AddTokens", + "args": {"init_token": "<sos>", "pad_token": "_", "eos_token": "<eos>"}, + } + ], + ) + dataset.load_or_generate_data() + + for index in [0, 1, 3]: + image, target = dataset[index] + if len(image.shape) == 3: + image = image.squeeze(0) + print(image.sum(), image.dtype) + + label = "".join(dataset.mapper(label) for label in target[1:]).strip( + dataset.mapper.pad_token + ) + print(label) + image = image.numpy() + util.write_image(image, str(SUPPORT_DIRNAME / f"{label}.png")) + + +if __name__ == "__main__": + create_emnist_lines_support_files() diff --git a/text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png b/text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png Binary files differnew file mode 100644 index 0000000..b7d0618 --- /dev/null +++ b/text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png diff --git a/text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png b/text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png Binary files differnew file mode 100644 index 0000000..14a8cf3 --- /dev/null +++ b/text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png diff --git a/text_recognizer/tests/support/emnist_lines/they<eos>.png b/text_recognizer/tests/support/emnist_lines/they<eos>.png Binary files differnew file mode 100644 index 0000000..7f05951 --- /dev/null +++ b/text_recognizer/tests/support/emnist_lines/they<eos>.png diff --git a/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png b/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png Binary files differnew file mode 100644 index 0000000..6eeb642 --- /dev/null +++ b/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png diff --git a/text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png b/text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png Binary files differnew file mode 100644 index 0000000..4974cf8 --- /dev/null +++ b/text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png diff --git a/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png b/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png Binary files differnew file mode 100644 index 0000000..a731245 --- /dev/null +++ b/text_recognizer/tests/support/iam_lines/his entrance. 
He came, almost falling<eos>.png
diff --git a/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg b/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg
Binary files differ
new file mode 100644
index 0000000..d9753b6
--- /dev/null
+++ b/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg
diff --git a/text_recognizer/tests/test_character_predictor.py b/text_recognizer/tests/test_character_predictor.py
new file mode 100644
index 0000000..01bda78
--- /dev/null
+++ b/text_recognizer/tests/test_character_predictor.py
@@ -0,0 +1,31 @@
+"""Test for CharacterPredictor class."""
+import os
+from pathlib import Path
+import unittest
+
+from loguru import logger
+
+from text_recognizer.character_predictor import CharacterPredictor
+
+SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "support" / "emnist"
+
+os.environ["CUDA_VISIBLE_DEVICES"] = ""
+
+
+class TestCharacterPredictor(unittest.TestCase):
+    """Tests for the CharacterPredictor class."""
+
+    def test_filename(self) -> None:
+        """Test that CharacterPredictor correctly predicts on a single image, for several test images."""
+        predictor = CharacterPredictor(network_fn="MLP", dataset="EmnistDataset")
+
+        for filename in SUPPORT_DIRNAME.glob("*.png"):
+            pred, conf = predictor.predict(str(filename))
+            logger.info(
+                f"Prediction: {pred} at confidence: {conf} for image with character {filename.stem}"
+            )
+            self.assertEqual(pred, filename.stem)
+            self.assertGreater(conf, 0.7)
diff --git a/text_recognizer/tests/test_line_predictor.py b/text_recognizer/tests/test_line_predictor.py
new file mode 100644
index 0000000..eede4d4
--- /dev/null
+++ b/text_recognizer/tests/test_line_predictor.py
@@ -0,0 +1,35 @@
+"""Tests for LinePredictor."""
+import os
+from pathlib import Path
+import unittest
+
+
+import editdistance
+import numpy as np
+
+from text_recognizer.datasets import IamLinesDataset
+from text_recognizer.line_predictor import LinePredictor
+import text_recognizer.util as util
+
+SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "support"
+
+os.environ["CUDA_VISIBLE_DEVICES"] = ""
+
+
+class TestEmnistLinePredictor(unittest.TestCase):
+    """Test LinePredictor class on the EmnistLines dataset."""
+
+    def test_filename(self) -> None:
+        """Test that LinePredictor correctly predicts on single images, for several test images."""
+        predictor = LinePredictor(
+            dataset="EmnistLinesDataset", network_fn="CNNTransformer"
+        )
+
+        for filename in (SUPPORT_DIRNAME / "emnist_lines").glob("*.png"):
+            pred, conf = predictor.predict(str(filename))
+            true = str(filename.stem)
+            edit_distance = editdistance.eval(pred, true) / len(pred)
+            print(
+                f'Pred: "{pred}" | Confidence: {conf} | True: {true} | Edit distance: {edit_distance}'
+            )
+            self.assertLess(edit_distance, 0.2)
diff --git a/text_recognizer/tests/test_paragraph_text_recognizer.py b/text_recognizer/tests/test_paragraph_text_recognizer.py
new file mode 100644
index 0000000..3e280b9
--- /dev/null
+++ b/text_recognizer/tests/test_paragraph_text_recognizer.py
@@ -0,0 +1,37 @@
+"""Test for ParagraphTextRecognizer class."""
+import os
+from pathlib import Path
+import unittest
+
+from text_recognizer.paragraph_text_recognizer import ParagraphTextRecognizor
+import text_recognizer.util as util
+
+
+SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "support" / "iam_paragraphs"
+
+# Prevent using GPU.
+os.environ["CUDA_VISIBLE_DEVICES"] = ""
+
+
+class TestParagraphTextRecognizor(unittest.TestCase):
+    """Test that it can take non-square images of max dimension larger than 256px."""
+
+    def test_filename(self) -> None:
+        """Test model on support image."""
+        line_predictor_args = {
+            "dataset": "EmnistLinesDataset",
+            "network_fn": "CNNTransformer",
+        }
+        line_detector_args = {"dataset": "EmnistLinesDataset", "network_fn": "UNet"}
+        model = ParagraphTextRecognizor(
+            line_predictor_args=line_predictor_args,
+            line_detector_args=line_detector_args,
+        )
+        num_text_lines_by_name = {"a01-000u": 7}
+        for filename in SUPPORT_DIRNAME.glob("*.jpg"):
+            full_image = util.read_image(str(filename), grayscale=True)
+            predicted_text, line_region_crops = model.predict(full_image)
+            print(predicted_text)
+            self.assertEqual(
+                len(line_region_crops), num_text_lines_by_name[filename.stem]
+            )
diff --git a/text_recognizer/util.py b/text_recognizer/util.py
new file mode 100644
index 0000000..b431e22
--- /dev/null
+++ b/text_recognizer/util.py
@@ -0,0 +1,52 @@
+"""Utility functions for text_recognizer module."""
+import os
+from pathlib import Path
+from typing import Union
+from urllib.request import urlopen
+
+import cv2
+import numpy as np
+
+
+def read_image(image_uri: Union[Path, str], grayscale: bool = False) -> np.ndarray:
+    """Read image_uri."""
+
+    def read_image_from_filename(image_filename: str, imread_flag: int) -> np.ndarray:
+        return cv2.imread(str(image_filename), imread_flag)
+
+    def read_image_from_url(image_url: str, imread_flag: int) -> np.ndarray:
+        if image_url.lower().startswith("http"):
+            url_response = urlopen(str(image_url))
+            image_array = np.array(bytearray(url_response.read()), dtype=np.uint8)
+            return cv2.imdecode(image_array, imread_flag)
+        else:
+            raise ValueError(
+                "Url does not start with http, therefore not safe to open..."
+ ) from None + + imread_flag = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR + local_file = os.path.exists(image_uri) + image = None + + if local_file: + image = read_image_from_filename(image_uri, imread_flag) + else: + image = read_image_from_url(image_uri, imread_flag) + + if image is None: + raise ValueError(f"Could not load image at {image_uri}") + + return image + + +def rescale_image(image: np.ndarray) -> np.ndarray: + """Rescale image from [0, 1] to [0, 255].""" + if image.max() <= 1.0: + image = 255 * (image - image.min()) / (image.max() - image.min()) + return image + + +def write_image(image: np.ndarray, filename: Union[Path, str]) -> None: + """Write image to file.""" + image = rescale_image(image) + cv2.imwrite(str(filename), image) diff --git a/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt b/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt new file mode 100644 index 0000000..344e0a3 --- /dev/null +++ b/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:46d483950ef0876ba072d06cd94021e08d99c4fa14eeccf22aeae1cbb2066b4f +size 5628749 diff --git a/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt b/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt new file mode 100644 index 0000000..f2dfd84 --- /dev/null +++ b/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8a69e5efedea70c4c5cb8ccdcc8cd480400f6c73e3313423f4dbbfe615644f0a +size 4500617 diff --git a/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt b/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt new file mode 100644 index 0000000..e1add8d --- /dev/null +++ b/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:68dd5c98eedc8753546f88b4e6fd5fc38725dc0079b837c30fb3d48069ec412b +size 15002754 diff --git a/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_FCN_weights.pt b/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_FCN_weights.pt Binary files differnew file mode 100644 index 0000000..d9ca01d --- /dev/null +++ b/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_FCN_weights.pt diff --git a/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_UNet_weights.pt b/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_UNet_weights.pt Binary files differnew file mode 100644 index 0000000..0af0e57 --- /dev/null +++ b/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_UNet_weights.pt diff --git a/text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.pt b/text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.pt Binary files differnew file mode 100644 index 0000000..b5295c2 --- /dev/null +++ b/text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.pt |
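The checkpoint files above are stored as Git LFS pointers, so the binaries must be fetched with `git lfs pull` before loading. A minimal, illustrative loading sketch (assuming the checkpoint layout with "network_args" and "model_state" keys that `configure_backbone` in networks/util.py expects, and that `WideResidualNetwork` is exported from `text_recognizer.networks` like the other networks):

    import torch

    from text_recognizer.networks import WideResidualNetwork

    # Load the checkpoint on CPU and rebuild the network from its saved args.
    state_dict = torch.load(
        "text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt",
        map_location="cpu",
    )
    network = WideResidualNetwork(**state_dict["network_args"])
    network.load_state_dict(state_dict["model_state"])
    network.eval()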