diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-20 18:09:06 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-20 18:09:06 +0100 |
commit | 7e8e54e84c63171e748bbf09516fd517e6821ace (patch) | |
tree | 996093f75a5d488dddf7ea1f159ed343a561ef89 /src/text_recognizer | |
parent | b0719d84138b6bbe5f04a4982dfca673aea1a368 (diff) |
Inital commit for refactoring to lightning
Diffstat (limited to 'src/text_recognizer')
81 files changed, 0 insertions, 7438 deletions
diff --git a/src/text_recognizer/__init__.py b/src/text_recognizer/__init__.py deleted file mode 100644 index 3dc1f76..0000000 --- a/src/text_recognizer/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = "0.1.0" diff --git a/src/text_recognizer/character_predictor.py b/src/text_recognizer/character_predictor.py deleted file mode 100644 index ad71289..0000000 --- a/src/text_recognizer/character_predictor.py +++ /dev/null @@ -1,29 +0,0 @@ -"""CharacterPredictor class.""" -from typing import Dict, Tuple, Type, Union - -import numpy as np -from torch import nn - -from text_recognizer import datasets, networks -from text_recognizer.models import CharacterModel -from text_recognizer.util import read_image - - -class CharacterPredictor: - """Recognizes the character in handwritten character images.""" - - def __init__(self, network_fn: str, dataset: str) -> None: - """Intializes the CharacterModel and load the pretrained weights.""" - network_fn = getattr(networks, network_fn) - dataset = getattr(datasets, dataset) - self.model = CharacterModel(network_fn=network_fn, dataset=dataset) - self.model.eval() - self.model.use_swa_model() - - def predict(self, image_or_filename: Union[np.ndarray, str]) -> Tuple[str, float]: - """Predict on a single images contianing a handwritten character.""" - if isinstance(image_or_filename, str): - image = read_image(image_or_filename, grayscale=True) - else: - image = image_or_filename - return self.model.predict_on_image(image) diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py deleted file mode 100644 index a6c1c59..0000000 --- a/src/text_recognizer/datasets/__init__.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Dataset modules.""" -from .emnist_dataset import EmnistDataset -from .emnist_lines_dataset import ( - construct_image_from_string, - EmnistLinesDataset, - get_samples_by_character, -) -from .iam_dataset import IamDataset -from .iam_lines_dataset import IamLinesDataset -from 
.iam_paragraphs_dataset import IamParagraphsDataset -from .iam_preprocessor import load_metadata, Preprocessor -from .transforms import AddTokens, Transpose -from .util import ( - _download_raw_dataset, - compute_sha256, - DATA_DIRNAME, - download_url, - EmnistMapper, - ESSENTIALS_FILENAME, -) - -__all__ = [ - "_download_raw_dataset", - "AddTokens", - "compute_sha256", - "construct_image_from_string", - "DATA_DIRNAME", - "download_url", - "EmnistDataset", - "EmnistMapper", - "EmnistLinesDataset", - "get_samples_by_character", - "load_metadata", - "IamDataset", - "IamLinesDataset", - "IamParagraphsDataset", - "Preprocessor", - "Transpose", -] diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py deleted file mode 100644 index e794605..0000000 --- a/src/text_recognizer/datasets/dataset.py +++ /dev/null @@ -1,152 +0,0 @@ -"""Abstract dataset class.""" -from typing import Callable, Dict, List, Optional, Tuple, Union - -import torch -from torch import Tensor -from torch.utils import data -from torchvision.transforms import ToTensor - -import text_recognizer.datasets.transforms as transforms -from text_recognizer.datasets.util import EmnistMapper - - -class Dataset(data.Dataset): - """Abstract class for with common methods for all datasets.""" - - def __init__( - self, - train: bool, - subsample_fraction: float = None, - transform: Optional[List[Dict]] = None, - target_transform: Optional[List[Dict]] = None, - init_token: Optional[str] = None, - pad_token: Optional[str] = None, - eos_token: Optional[str] = None, - lower: bool = False, - ) -> None: - """Initialization of Dataset class. - - Args: - train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False. - subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None. - transform (Optional[List[Dict]]): List of Transform types and args for input data. Defaults to None. 
- target_transform (Optional[List[Dict]]): List of Transform types and args for output data. Defaults to None. - init_token (Optional[str]): String representing the start of sequence token. Defaults to None. - pad_token (Optional[str]): String representing the pad token. Defaults to None. - eos_token (Optional[str]): String representing the end of sequence token. Defaults to None. - lower (bool): Only use lower case letters. Defaults to False. - - Raises: - ValueError: If subsample_fraction is not None and outside the range (0, 1). - - """ - self.train = train - self.split = "train" if self.train else "test" - - if subsample_fraction is not None: - if not 0.0 < subsample_fraction < 1.0: - raise ValueError("The subsample fraction must be in (0, 1).") - self.subsample_fraction = subsample_fraction - - self._mapper = EmnistMapper( - init_token=init_token, eos_token=eos_token, pad_token=pad_token, lower=lower - ) - self._input_shape = self._mapper.input_shape - self._output_shape = self._mapper._num_classes - self.num_classes = self.mapper.num_classes - - # Set transforms. 
- self.transform = self._configure_transform(transform) - self.target_transform = self._configure_target_transform(target_transform) - - self._data = None - self._targets = None - - def _configure_transform(self, transform: List[Dict]) -> transforms.Compose: - transform_list = [] - if transform is not None: - for t in transform: - t_type = t["type"] - t_args = t["args"] or {} - transform_list.append(getattr(transforms, t_type)(**t_args)) - else: - transform_list.append(ToTensor()) - return transforms.Compose(transform_list) - - def _configure_target_transform( - self, target_transform: List[Dict] - ) -> transforms.Compose: - target_transform_list = [torch.tensor] - if target_transform is not None: - for t in target_transform: - t_type = t["type"] - t_args = t["args"] or {} - target_transform_list.append(getattr(transforms, t_type)(**t_args)) - return transforms.Compose(target_transform_list) - - @property - def data(self) -> Tensor: - """The input data.""" - return self._data - - @property - def targets(self) -> Tensor: - """The target data.""" - return self._targets - - @property - def input_shape(self) -> Tuple: - """Input shape of the data.""" - return self._input_shape - - @property - def output_shape(self) -> Tuple: - """Output shape of the data.""" - return self._output_shape - - @property - def mapper(self) -> EmnistMapper: - """Returns the EmnistMapper.""" - return self._mapper - - @property - def mapping(self) -> Dict: - """Return EMNIST mapping from index to character.""" - return self._mapper.mapping - - @property - def inverse_mapping(self) -> Dict: - """Returns the inverse mapping from character to index.""" - return self.mapper.inverse_mapping - - def _subsample(self) -> None: - """Only this fraction of the data will be loaded.""" - if self.subsample_fraction is None: - return - num_subsample = int(self.data.shape[0] * self.subsample_fraction) - self._data = self.data[:num_subsample] - self._targets = self.targets[:num_subsample] - - def __len__(self) 
-> int: - """Returns the length of the dataset.""" - return len(self.data) - - def load_or_generate_data(self) -> None: - """Load or generate dataset data.""" - raise NotImplementedError - - def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: - """Fetches samples from the dataset. - - Args: - index (Union[int, torch.Tensor]): The indices of the samples to fetch. - - Raises: - NotImplementedError: If the method is not implemented in child class. - - """ - raise NotImplementedError - - def __repr__(self) -> str: - """Returns information about the dataset.""" - raise NotImplementedError diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py deleted file mode 100644 index 9884fdf..0000000 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ /dev/null @@ -1,131 +0,0 @@ -"""Emnist dataset: black and white images of handwritten characters (Aa-Zz) and digits (0-9).""" - -import json -from pathlib import Path -from typing import Callable, Optional, Tuple, Union - -from loguru import logger -import numpy as np -from PIL import Image -import torch -from torch import Tensor -from torchvision.datasets import EMNIST -from torchvision.transforms import Compose, ToTensor - -from text_recognizer.datasets.dataset import Dataset -from text_recognizer.datasets.transforms import Transpose -from text_recognizer.datasets.util import DATA_DIRNAME - - -class EmnistDataset(Dataset): - """This is a class for resampling and subsampling the PyTorch EMNIST dataset.""" - - def __init__( - self, - pad_token: str = None, - train: bool = False, - sample_to_balance: bool = False, - subsample_fraction: float = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - seed: int = 4711, - ) -> None: - """Loads the dataset and the mappings. - - Args: - pad_token (str): The pad token symbol. Defaults to _. 
- train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False. - sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False. - subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None. - transform (Optional[Callable]): Transform(s) for input data. Defaults to None. - target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None. - seed (int): Seed number. Defaults to 4711. - - """ - super().__init__( - train=train, - subsample_fraction=subsample_fraction, - transform=transform, - target_transform=target_transform, - pad_token=pad_token, - ) - - self.sample_to_balance = sample_to_balance - - # Have to transpose the emnist characters, ToTensor norms input between [0,1]. - if transform is None: - self.transform = Compose([Transpose(), ToTensor()]) - - self.target_transform = None - - self.seed = seed - - def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: - """Fetches samples from the dataset. - - Args: - index (Union[int, Tensor]): The indices of the samples to fetch. - - Returns: - Tuple[Tensor, Tensor]: Data target tuple. 
- - """ - if torch.is_tensor(index): - index = index.tolist() - - data = self.data[index] - targets = self.targets[index] - - if self.transform: - data = self.transform(data) - - if self.target_transform: - targets = self.target_transform(targets) - - return data, targets - - def __repr__(self) -> str: - """Returns information about the dataset.""" - return ( - "EMNIST Dataset\n" - f"Num classes: {self.num_classes}\n" - f"Input shape: {self.input_shape}\n" - f"Mapping: {self.mapper.mapping}\n" - ) - - def _sample_to_balance(self) -> None: - """Because the dataset is not balanced, we take at most the mean number of instances per class.""" - np.random.seed(self.seed) - x = self._data - y = self._targets - num_to_sample = int(np.bincount(y.flatten()).mean()) - all_sampled_indices = [] - for label in np.unique(y.flatten()): - inds = np.where(y == label)[0] - sampled_indices = np.unique(np.random.choice(inds, num_to_sample)) - all_sampled_indices.append(sampled_indices) - indices = np.concatenate(all_sampled_indices) - x_sampled = x[indices] - y_sampled = y[indices] - self._data = x_sampled - self._targets = y_sampled - - def load_or_generate_data(self) -> None: - """Fetch the EMNIST dataset.""" - dataset = EMNIST( - root=DATA_DIRNAME, - split="byclass", - train=self.train, - download=False, - transform=None, - target_transform=None, - ) - - self._data = dataset.data - self._targets = dataset.targets - - if self.sample_to_balance: - self._sample_to_balance() - - if self.subsample_fraction is not None: - self._subsample() diff --git a/src/text_recognizer/datasets/emnist_essentials.json b/src/text_recognizer/datasets/emnist_essentials.json deleted file mode 100644 index 2a0648a..0000000 --- a/src/text_recognizer/datasets/emnist_essentials.json +++ /dev/null @@ -1 +0,0 @@ -{"mapping": [[0, "0"], [1, "1"], [2, "2"], [3, "3"], [4, "4"], [5, "5"], [6, "6"], [7, "7"], [8, "8"], [9, "9"], [10, "A"], [11, "B"], [12, "C"], [13, "D"], [14, "E"], [15, "F"], [16, "G"], [17, "H"], 
[18, "I"], [19, "J"], [20, "K"], [21, "L"], [22, "M"], [23, "N"], [24, "O"], [25, "P"], [26, "Q"], [27, "R"], [28, "S"], [29, "T"], [30, "U"], [31, "V"], [32, "W"], [33, "X"], [34, "Y"], [35, "Z"], [36, "a"], [37, "b"], [38, "c"], [39, "d"], [40, "e"], [41, "f"], [42, "g"], [43, "h"], [44, "i"], [45, "j"], [46, "k"], [47, "l"], [48, "m"], [49, "n"], [50, "o"], [51, "p"], [52, "q"], [53, "r"], [54, "s"], [55, "t"], [56, "u"], [57, "v"], [58, "w"], [59, "x"], [60, "y"], [61, "z"]], "input_shape": [28, 28]} diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py deleted file mode 100644 index 1992446..0000000 --- a/src/text_recognizer/datasets/emnist_lines_dataset.py +++ /dev/null @@ -1,359 +0,0 @@ -"""Emnist Lines dataset: synthetic handwritten lines dataset made from Emnist characters.""" - -from collections import defaultdict -from pathlib import Path -from typing import Callable, Dict, List, Optional, Tuple, Union - -import click -import h5py -from loguru import logger -import numpy as np -import torch -from torch import Tensor -import torch.nn.functional as F -from torchvision.transforms import ToTensor - -from text_recognizer.datasets.dataset import Dataset -from text_recognizer.datasets.emnist_dataset import EmnistDataset, Transpose -from text_recognizer.datasets.sentence_generator import SentenceGenerator -from text_recognizer.datasets.util import ( - DATA_DIRNAME, - EmnistMapper, - ESSENTIALS_FILENAME, -) - -DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines" - -MAX_WIDTH = 952 - - -class EmnistLinesDataset(Dataset): - """Synthetic dataset of lines from the Brown corpus with Emnist characters.""" - - def __init__( - self, - train: bool = False, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - subsample_fraction: float = None, - max_length: int = 34, - min_overlap: float = 0, - max_overlap: float = 0.33, - num_samples: int = 10000, - seed: int = 
4711, - init_token: Optional[str] = None, - pad_token: Optional[str] = None, - eos_token: Optional[str] = None, - lower: bool = False, - ) -> None: - """Set attributes and loads the dataset. - - Args: - train (bool): Flag for the filename. Defaults to False. Defaults to None. - transform (Optional[Callable]): The transform of the data. Defaults to None. - target_transform (Optional[Callable]): The transform of the target. Defaults to None. - subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None. - max_length (int): The maximum number of characters. Defaults to 34. - min_overlap (float): The minimum overlap between concatenated images. Defaults to 0. - max_overlap (float): The maximum overlap between concatenated images. Defaults to 0.33. - num_samples (int): Number of samples to generate. Defaults to 10000. - seed (int): Seed number. Defaults to 4711. - init_token (Optional[str]): String representing the start of sequence token. Defaults to None. - pad_token (Optional[str]): String representing the pad token. Defaults to None. - eos_token (Optional[str]): String representing the end of sequence token. Defaults to None. - lower (bool): If True, convert uppercase letters to lowercase. Otherwise, use both upper and lowercase. - - """ - self.pad_token = "_" if pad_token is None else pad_token - - super().__init__( - train=train, - transform=transform, - target_transform=target_transform, - subsample_fraction=subsample_fraction, - init_token=init_token, - pad_token=self.pad_token, - eos_token=eos_token, - lower=lower, - ) - - # Extract dataset information. 
- self._input_shape = self._mapper.input_shape - self.num_classes = self._mapper.num_classes - - self.max_length = max_length - self.min_overlap = min_overlap - self.max_overlap = max_overlap - self.num_samples = num_samples - self._input_shape = ( - self.input_shape[0], - self.input_shape[1] * self.max_length, - ) - self._output_shape = (self.max_length, self.num_classes) - self.seed = seed - - # Placeholders for the dataset. - self._data = None - self._target = None - - def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: - """Fetches data, target pair of the dataset for a given and index or indices. - - Args: - index (Union[int, Tensor]): Either a list or int of indices/index. - - Returns: - Tuple[Tensor, Tensor]: Data target pair. - - """ - if torch.is_tensor(index): - index = index.tolist() - - data = self.data[index] - targets = self.targets[index] - - if self.transform: - data = self.transform(data) - - if self.target_transform: - targets = self.target_transform(targets) - - return data, targets - - def __repr__(self) -> str: - """Returns information about the dataset.""" - return ( - "EMNIST Lines Dataset\n" # pylint: disable=no-member - f"Max length: {self.max_length}\n" - f"Min overlap: {self.min_overlap}\n" - f"Max overlap: {self.max_overlap}\n" - f"Num classes: {self.num_classes}\n" - f"Input shape: {self.input_shape}\n" - f"Data: {self.data.shape}\n" - f"Tagets: {self.targets.shape}\n" - ) - - @property - def data_filename(self) -> Path: - """Path to the h5 file.""" - filename = "train.pt" if self.train else "test.pt" - return DATA_DIRNAME / filename - - def load_or_generate_data(self) -> None: - """Loads the dataset, if it does not exist a new dataset is generated before loading it.""" - np.random.seed(self.seed) - - if not self.data_filename.exists(): - self._generate_data() - self._load_data() - self._subsample() - - def _load_data(self) -> None: - """Loads the dataset from the h5 file.""" - logger.debug("EmnistLinesDataset 
loading data from HDF5...") - with h5py.File(self.data_filename, "r") as f: - self._data = f["data"][()] - self._targets = f["targets"][()] - - def _generate_data(self) -> str: - """Generates a dataset with the Brown corpus and Emnist characters.""" - logger.debug("Generating data...") - - sentence_generator = SentenceGenerator(self.max_length) - - # Load emnist dataset. - emnist = EmnistDataset( - train=self.train, sample_to_balance=True, pad_token=self.pad_token - ) - emnist.load_or_generate_data() - - samples_by_character = get_samples_by_character( - emnist.data.numpy(), emnist.targets.numpy(), self.mapper.mapping, - ) - - DATA_DIRNAME.mkdir(parents=True, exist_ok=True) - with h5py.File(self.data_filename, "a") as f: - data, targets = create_dataset_of_images( - self.num_samples, - samples_by_character, - sentence_generator, - self.min_overlap, - self.max_overlap, - ) - - targets = convert_strings_to_categorical_labels( - targets, emnist.inverse_mapping - ) - - f.create_dataset("data", data=data, dtype="u1", compression="lzf") - f.create_dataset("targets", data=targets, dtype="u1", compression="lzf") - - -def get_samples_by_character( - samples: np.ndarray, labels: np.ndarray, mapping: Dict -) -> defaultdict: - """Creates a dictionary with character as key and value as the list of images of that character. - - Args: - samples (np.ndarray): Dataset of images of characters. - labels (np.ndarray): The labels for each image. - mapping (Dict): The Emnist mapping dictionary. - - Returns: - defaultdict: A dictionary with characters as keys and list of images as values. - - """ - samples_by_character = defaultdict(list) - for sample, label in zip(samples, labels.flatten()): - samples_by_character[mapping[label]].append(sample) - return samples_by_character - - -def select_letter_samples_for_string( - string: str, samples_by_character: Dict -) -> List[np.ndarray]: - """Randomly selects Emnist characters to use for the senetence. 
- - Args: - string (str): The word or sentence. - samples_by_character (Dict): The dictionary of emnist images of each character. - - Returns: - List[np.ndarray]: A list of emnist images of the string. - - """ - zero_image = np.zeros((28, 28), np.uint8) - sample_image_by_character = {} - for character in string: - if character in sample_image_by_character: - continue - samples = samples_by_character[character] - sample = samples[np.random.choice(len(samples))] if samples else zero_image - sample_image_by_character[character] = sample.reshape(28, 28).swapaxes(0, 1) - return [sample_image_by_character[character] for character in string] - - -def construct_image_from_string( - string: str, samples_by_character: Dict, min_overlap: float, max_overlap: float -) -> np.ndarray: - """Concatenates images of the characters in the string. - - The concatination is made with randomly selected overlap so that some portion of the character will overlap. - - Args: - string (str): The word or sentence. - samples_by_character (Dict): The dictionary of emnist images of each character. - min_overlap (float): Minimum amount of overlap between Emnist images. - max_overlap (float): Maximum amount of overlap between Emnist images. - - Returns: - np.ndarray: The Emnist image of the string. 
- - """ - overlap = np.random.uniform(min_overlap, max_overlap) - sampled_images = select_letter_samples_for_string(string, samples_by_character) - length = len(sampled_images) - height, width = sampled_images[0].shape - next_overlap_width = width - int(overlap * width) - concatenated_image = np.zeros((height, width * length), np.uint8) - x = 0 - for image in sampled_images: - concatenated_image[:, x : (x + width)] += image - x += next_overlap_width - - if concatenated_image.shape[-1] > MAX_WIDTH: - concatenated_image = Tensor(concatenated_image).unsqueeze(0) - concatenated_image = F.interpolate( - concatenated_image, size=MAX_WIDTH, mode="nearest" - ) - concatenated_image = concatenated_image.squeeze(0).numpy() - - return np.minimum(255, concatenated_image) - - -def create_dataset_of_images( - length: int, - samples_by_character: Dict, - sentence_generator: SentenceGenerator, - min_overlap: float, - max_overlap: float, -) -> Tuple[np.ndarray, List[str]]: - """Creates a dataset with images and labels from strings generated from the SentenceGenerator. - - Args: - length (int): The number of characters for each string. - samples_by_character (Dict): The dictionary of emnist images of each character. - sentence_generator (SentenceGenerator): A SentenceGenerator objest. - min_overlap (float): Minimum amount of overlap between Emnist images. - max_overlap (float): Maximum amount of overlap between Emnist images. - - Returns: - Tuple[np.ndarray, List[str]]: A list of Emnist images and a list of the strings (labels). - - Raises: - RuntimeError: If the sentence generator is not able to generate a string. - - """ - sample_label = sentence_generator.generate() - sample_image = construct_image_from_string(sample_label, samples_by_character, 0, 0) - images = np.zeros((length, sample_image.shape[0], sample_image.shape[1]), np.uint8) - labels = [] - for n in range(length): - label = None - # Try several times to generate before actually throwing an error. 
- for _ in range(10): - try: - label = sentence_generator.generate() - break - except Exception: # pylint: disable=broad-except - pass - if label is None: - raise RuntimeError("Was not able to generate a valid string.") - images[n] = construct_image_from_string( - label, samples_by_character, min_overlap, max_overlap - ) - labels.append(label) - return images, labels - - -def convert_strings_to_categorical_labels( - labels: List[str], mapping: Dict -) -> np.ndarray: - """Translates a string of characters in to a target array of class int.""" - return np.array([[mapping[c] for c in label] for label in labels]) - - -@click.command() -@click.option( - "--max_length", type=int, default=34, help="Number of characters in a sentence." -) -@click.option( - "--min_overlap", type=float, default=0.0, help="Min overlap between characters." -) -@click.option( - "--max_overlap", type=float, default=0.33, help="Max overlap between characters." -) -@click.option("--num_train", type=int, default=10_000, help="Number of train examples.") -@click.option("--num_test", type=int, default=1_000, help="Number of test examples.") -def create_datasets( - max_length: int = 34, - min_overlap: float = 0, - max_overlap: float = 0.33, - num_train: int = 10000, - num_test: int = 1000, -) -> None: - """Creates a training an validation dataset of Emnist lines.""" - num_samples = [num_train, num_test] - for num, train in zip(num_samples, [True, False]): - emnist_lines = EmnistLinesDataset( - train=train, - max_length=max_length, - min_overlap=min_overlap, - max_overlap=max_overlap, - num_samples=num, - ) - emnist_lines.load_or_generate_data() - - -if __name__ == "__main__": - create_datasets() diff --git a/src/text_recognizer/datasets/iam_dataset.py b/src/text_recognizer/datasets/iam_dataset.py deleted file mode 100644 index f4a869d..0000000 --- a/src/text_recognizer/datasets/iam_dataset.py +++ /dev/null @@ -1,132 +0,0 @@ -"""Class for loading the IAM dataset, which encompasses both paragraphs and 
lines, with associated utilities.""" -import os -from typing import Any, Dict, List -import zipfile - -from boltons.cacheutils import cachedproperty -import defusedxml.ElementTree as ET -from loguru import logger -import toml - -from text_recognizer.datasets.util import _download_raw_dataset, DATA_DIRNAME - -RAW_DATA_DIRNAME = DATA_DIRNAME / "raw" / "iam" -METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" -EXTRACTED_DATASET_DIRNAME = RAW_DATA_DIRNAME / "iamdb" - -DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be. -LINE_REGION_PADDING = 0 # Add this many pixels around the exact coordinates. - - -class IamDataset: - """IAM dataset. - - "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text, - which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels." - From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database - - The data split we will use is - IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text lines. - The validation set has been merged into the train set. - The train set has 7,101 lines from 326 writers. - The test set has 1,861 lines from 128 writers. - The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only. 
- - """ - - def __init__(self) -> None: - self.metadata = toml.load(METADATA_FILENAME) - - def load_or_generate_data(self) -> None: - """Downloads IAM dataset if xml files does not exist.""" - if not self.xml_filenames: - self._download_iam() - - @property - def xml_filenames(self) -> List: - """List of xml filenames.""" - return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml")) - - @property - def form_filenames(self) -> List: - """List of forms filenames.""" - return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg")) - - def _download_iam(self) -> None: - curdir = os.getcwd() - os.chdir(RAW_DATA_DIRNAME) - _download_raw_dataset(self.metadata) - _extract_raw_dataset(self.metadata) - os.chdir(curdir) - - @property - def form_filenames_by_id(self) -> Dict: - """Creates a dictionary with filenames as keys and forms as values.""" - return {filename.stem: filename for filename in self.form_filenames} - - @cachedproperty - def line_strings_by_id(self) -> Dict: - """Return a dict from name of IAM form to a list of line texts in it.""" - return { - filename.stem: _get_line_strings_from_xml_file(filename) - for filename in self.xml_filenames - } - - @cachedproperty - def line_regions_by_id(self) -> Dict: - """Return a dict from name of IAM form to a list of (x1, x2, y1, y2) coordinates of all lines in it.""" - return { - filename.stem: _get_line_regions_from_xml_file(filename) - for filename in self.xml_filenames - } - - def __repr__(self) -> str: - """Print info about dataset.""" - return "IAM Dataset\n" f"Number of forms: {len(self.xml_filenames)}\n" - - -def _extract_raw_dataset(metadata: Dict) -> None: - logger.info("Extracting IAM data.") - with zipfile.ZipFile(metadata["filename"], "r") as zip_file: - zip_file.extractall() - - -def _get_line_strings_from_xml_file(filename: str) -> List[str]: - """Get the text content of each line. 
Note that we replace " with ".""" - xml_root_element = ET.parse(filename).getroot() # nosec - xml_line_elements = xml_root_element.findall("handwritten-part/line") - return [el.attrib["text"].replace(""", '"') for el in xml_line_elements] - - -def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]: - """Get the line region dict for each line.""" - xml_root_element = ET.parse(filename).getroot() # nosec - xml_line_elements = xml_root_element.findall("handwritten-part/line") - return [_get_line_region_from_xml_element(el) for el in xml_line_elements] - - -def _get_line_region_from_xml_element(xml_line: Any) -> Dict[str, int]: - """Extracts coordinates for each line of text.""" - # TODO: fix input! - word_elements = xml_line.findall("word/cmp") - x1s = [int(el.attrib["x"]) for el in word_elements] - y1s = [int(el.attrib["y"]) for el in word_elements] - x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements] - y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements] - return { - "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, - "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, - "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, - "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, - } - - -def main() -> None: - """Initializes the dataset and print info about the dataset.""" - dataset = IamDataset() - dataset.load_or_generate_data() - print(dataset) - - -if __name__ == "__main__": - main() diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/src/text_recognizer/datasets/iam_lines_dataset.py deleted file mode 100644 index 1cb84bd..0000000 --- a/src/text_recognizer/datasets/iam_lines_dataset.py +++ /dev/null @@ -1,110 +0,0 @@ -"""IamLinesDataset class.""" -from typing import Callable, Dict, List, Optional, Tuple, Union - -import h5py -from loguru import logger -import torch -from torch import Tensor -from torchvision.transforms import ToTensor - -from 
text_recognizer.datasets.dataset import Dataset -from text_recognizer.datasets.util import ( - compute_sha256, - DATA_DIRNAME, - download_url, - EmnistMapper, -) - - -PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines" -PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "iam_lines.h5" -PROCESSED_DATA_URL = ( - "https://s3-us-west-2.amazonaws.com/fsdl-public-assets/iam_lines.h5" -) - - -class IamLinesDataset(Dataset): - """IAM lines datasets for handwritten text lines.""" - - def __init__( - self, - train: bool = False, - subsample_fraction: float = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - init_token: Optional[str] = None, - pad_token: Optional[str] = None, - eos_token: Optional[str] = None, - lower: bool = False, - ) -> None: - self.pad_token = "_" if pad_token is None else pad_token - - super().__init__( - train=train, - subsample_fraction=subsample_fraction, - transform=transform, - target_transform=target_transform, - init_token=init_token, - pad_token=pad_token, - eos_token=eos_token, - lower=lower, - ) - - @property - def input_shape(self) -> Tuple: - """Input shape of the data.""" - return self.data.shape[1:] if self.data is not None else None - - @property - def output_shape(self) -> Tuple: - """Output shape of the data.""" - return ( - self.targets.shape[1:] + (self.num_classes,) - if self.targets is not None - else None - ) - - def load_or_generate_data(self) -> None: - """Load or generate dataset data.""" - if not PROCESSED_DATA_FILENAME.exists(): - PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) - logger.info("Downloading IAM lines...") - download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME) - with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: - self._data = f[f"x_{self.split}"][:] - self._targets = f[f"y_{self.split}"][:] - self._subsample() - - def __repr__(self) -> str: - """Print info about the dataset.""" - return ( - "IAM Lines Dataset\n" # pylint: disable=no-member - 
f"Number classes: {self.num_classes}\n" - f"Mapping: {self.mapper.mapping}\n" - f"Data: {self.data.shape}\n" - f"Targets: {self.targets.shape}\n" - ) - - def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]: - """Fetches data, target pair of the dataset for a given and index or indices. - - Args: - index (Union[int, Tensor]): Either a list or int of indices/index. - - Returns: - Tuple[Tensor, Tensor]: Data target pair. - - """ - if torch.is_tensor(index): - index = index.tolist() - - data = self.data[index] - targets = self.targets[index] - - if self.transform: - data = self.transform(data) - - if self.target_transform: - targets = self.target_transform(targets) - - return data, targets diff --git a/src/text_recognizer/datasets/iam_paragraphs_dataset.py b/src/text_recognizer/datasets/iam_paragraphs_dataset.py deleted file mode 100644 index 8ba5142..0000000 --- a/src/text_recognizer/datasets/iam_paragraphs_dataset.py +++ /dev/null @@ -1,291 +0,0 @@ -"""IamParagraphsDataset class and functions for data processing.""" -import random -from typing import Callable, Dict, List, Optional, Tuple, Union - -import click -import cv2 -import h5py -from loguru import logger -import numpy as np -import torch -from torch import Tensor -from torchvision.transforms import ToTensor - -from text_recognizer import util -from text_recognizer.datasets.dataset import Dataset -from text_recognizer.datasets.iam_dataset import IamDataset -from text_recognizer.datasets.util import ( - compute_sha256, - DATA_DIRNAME, - download_url, - EmnistMapper, -) - -INTERIM_DATA_DIRNAME = DATA_DIRNAME / "interim" / "iam_paragraphs" -DEBUG_CROPS_DIRNAME = INTERIM_DATA_DIRNAME / "debug_crops" -PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_paragraphs" -CROPS_DIRNAME = PROCESSED_DATA_DIRNAME / "crops" -GT_DIRNAME = PROCESSED_DATA_DIRNAME / "gt" - -PARAGRAPH_BUFFER = 50 # Pixels in the IAM form images to leave around the lines. 
-TEST_FRACTION = 0.2 -SEED = 4711 - - -class IamParagraphsDataset(Dataset): - """IAM Paragraphs dataset for paragraphs of handwritten text.""" - - def __init__( - self, - train: bool = False, - subsample_fraction: float = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - ) -> None: - super().__init__( - train=train, - subsample_fraction=subsample_fraction, - transform=transform, - target_transform=target_transform, - ) - # Load Iam dataset. - self.iam_dataset = IamDataset() - - self.num_classes = 3 - self._input_shape = (256, 256) - self._output_shape = self._input_shape + (self.num_classes,) - self._ids = None - - def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]: - """Fetches data, target pair of the dataset for a given and index or indices. - - Args: - index (Union[int, Tensor]): Either a list or int of indices/index. - - Returns: - Tuple[Tensor, Tensor]: Data target pair. - - """ - if torch.is_tensor(index): - index = index.tolist() - - data = self.data[index] - targets = self.targets[index] - - seed = np.random.randint(SEED) - random.seed(seed) # apply this seed to target tranfsorms - torch.manual_seed(seed) # needed for torchvision 0.7 - if self.transform: - data = self.transform(data) - - random.seed(seed) # apply this seed to target tranfsorms - torch.manual_seed(seed) # needed for torchvision 0.7 - if self.target_transform: - targets = self.target_transform(targets) - - return data, targets.long() - - @property - def ids(self) -> Tensor: - """Ids of the dataset.""" - return self._ids - - def get_data_and_target_from_id(self, id_: str) -> Tuple[Tensor, Tensor]: - """Get data target pair from id.""" - ind = self.ids.index(id_) - return self.data[ind], self.targets[ind] - - def load_or_generate_data(self) -> None: - """Load or generate dataset data.""" - num_actual = len(list(CROPS_DIRNAME.glob("*.jpg"))) - num_targets = len(self.iam_dataset.line_regions_by_id) - - if num_actual < 
num_targets - 2: - self._process_iam_paragraphs() - - self._data, self._targets, self._ids = _load_iam_paragraphs() - self._get_random_split() - self._subsample() - - def _get_random_split(self) -> None: - np.random.seed(SEED) - num_train = int((1 - TEST_FRACTION) * self.data.shape[0]) - indices = np.random.permutation(self.data.shape[0]) - train_indices, test_indices = indices[:num_train], indices[num_train:] - if self.train: - self._data = self.data[train_indices] - self._targets = self.targets[train_indices] - else: - self._data = self.data[test_indices] - self._targets = self.targets[test_indices] - - def _process_iam_paragraphs(self) -> None: - """Crop the part with the text. - - For each page, crop out the part of it that correspond to the paragraph of text, and make sure all crops are - self.input_shape. The ground truth data is the same size, with a one-hot vector at each pixel - corresponding to labels 0=background, 1=odd-numbered line, 2=even-numbered line - """ - crop_dims = self._decide_on_crop_dims() - CROPS_DIRNAME.mkdir(parents=True, exist_ok=True) - DEBUG_CROPS_DIRNAME.mkdir(parents=True, exist_ok=True) - GT_DIRNAME.mkdir(parents=True, exist_ok=True) - logger.info( - f"Cropping paragraphs, generating ground truth, and saving debugging images to {DEBUG_CROPS_DIRNAME}" - ) - for filename in self.iam_dataset.form_filenames: - id_ = filename.stem - line_region = self.iam_dataset.line_regions_by_id[id_] - _crop_paragraph_image(filename, line_region, crop_dims, self.input_shape) - - def _decide_on_crop_dims(self) -> Tuple[int, int]: - """Decide on the dimensions to crop out of the form image. - - Since image width is larger than a comfortable crop around the longest paragraph, - we will make the crop a square form factor. - And since the found dimensions 610x610 are pretty close to 512x512, - we might as well resize crops and make it exactly that, which lets us - do all kinds of power-of-2 pooling and upsampling should we choose to. 
- - Returns: - Tuple[int, int]: A tuple of crop dimensions. - - Raises: - RuntimeError: When max crop height is larger than max crop width. - - """ - - sample_form_filename = self.iam_dataset.form_filenames[0] - sample_image = util.read_image(sample_form_filename, grayscale=True) - max_crop_width = sample_image.shape[1] - max_crop_height = _get_max_paragraph_crop_height( - self.iam_dataset.line_regions_by_id - ) - if not max_crop_height <= max_crop_width: - raise RuntimeError( - f"Max crop height is larger then max crop width: {max_crop_height} >= {max_crop_width}" - ) - - crop_dims = (max_crop_width, max_crop_width) - logger.info( - f"Max crop width and height were found to be {max_crop_width}x{max_crop_height}." - ) - logger.info(f"Setting them to {max_crop_width}x{max_crop_width}") - return crop_dims - - def __repr__(self) -> str: - """Return info about the dataset.""" - return ( - "IAM Paragraph Dataset\n" # pylint: disable=no-member - f"Num classes: {self.num_classes}\n" - f"Data: {self.data.shape}\n" - f"Targets: {self.targets.shape}\n" - ) - - -def _get_max_paragraph_crop_height(line_regions_by_id: Dict) -> int: - heights = [] - for regions in line_regions_by_id.values(): - min_y1 = min(region["y1"] for region in regions) - PARAGRAPH_BUFFER - max_y2 = max(region["y2"] for region in regions) + PARAGRAPH_BUFFER - height = max_y2 - min_y1 - heights.append(height) - return max(heights) - - -def _crop_paragraph_image( - filename: str, line_regions: Dict, crop_dims: Tuple[int, int], final_dims: Tuple -) -> None: - image = util.read_image(filename, grayscale=True) - - min_y1 = min(region["y1"] for region in line_regions) - PARAGRAPH_BUFFER - max_y2 = max(region["y2"] for region in line_regions) + PARAGRAPH_BUFFER - height = max_y2 - min_y1 - crop_height = crop_dims[0] - buffer = (crop_height - height) // 2 - - # Generate image crop. 
- image_crop = 255 * np.ones(crop_dims, dtype=np.uint8) - try: - image_crop[buffer : buffer + height] = image[min_y1:max_y2] - except Exception as e: # pylint: disable=broad-except - logger.error(f"Rescued {filename}: {e}") - return - - # Generate ground truth. - gt_image = np.zeros_like(image_crop, dtype=np.uint8) - for index, region in enumerate(line_regions): - gt_image[ - (region["y1"] - min_y1 + buffer) : (region["y2"] - min_y1 + buffer), - region["x1"] : region["x2"], - ] = (index % 2 + 1) - - # Generate image for debugging. - import matplotlib.pyplot as plt - - cmap = plt.get_cmap("Set1") - image_crop_for_debug = np.dstack([image_crop, image_crop, image_crop]) - for index, region in enumerate(line_regions): - color = [255 * _ for _ in cmap(index)[:-1]] - cv2.rectangle( - image_crop_for_debug, - (region["x1"], region["y1"] - min_y1 + buffer), - (region["x2"], region["y2"] - min_y1 + buffer), - color, - 3, - ) - image_crop_for_debug = cv2.resize( - image_crop_for_debug, final_dims, interpolation=cv2.INTER_AREA - ) - util.write_image(image_crop_for_debug, DEBUG_CROPS_DIRNAME / f"{filename.stem}.jpg") - - image_crop = cv2.resize(image_crop, final_dims, interpolation=cv2.INTER_AREA) - util.write_image(image_crop, CROPS_DIRNAME / f"{filename.stem}.jpg") - - gt_image = cv2.resize(gt_image, final_dims, interpolation=cv2.INTER_NEAREST) - util.write_image(gt_image, GT_DIRNAME / f"{filename.stem}.png") - - -def _load_iam_paragraphs() -> None: - logger.info("Loading IAM paragraph crops and ground truth from image files...") - images = [] - gt_images = [] - ids = [] - for filename in CROPS_DIRNAME.glob("*.jpg"): - id_ = filename.stem - image = util.read_image(filename, grayscale=True) - image = 1.0 - image / 255 - - gt_filename = GT_DIRNAME / f"{id_}.png" - gt_image = util.read_image(gt_filename, grayscale=True) - - images.append(image) - gt_images.append(gt_image) - ids.append(id_) - images = np.array(images).astype(np.float32) - gt_images = 
np.array(gt_images).astype(np.uint8) - ids = np.array(ids) - return images, gt_images, ids - - -@click.command() -@click.option( - "--subsample_fraction", - type=float, - default=None, - help="The subsampling factor of the dataset.", -) -def main(subsample_fraction: float) -> None: - """Load dataset and print info.""" - logger.info("Creating train set...") - dataset = IamParagraphsDataset(train=True, subsample_fraction=subsample_fraction) - dataset.load_or_generate_data() - print(dataset) - logger.info("Creating test set...") - dataset = IamParagraphsDataset(subsample_fraction=subsample_fraction) - dataset.load_or_generate_data() - print(dataset) - - -if __name__ == "__main__": - main() diff --git a/src/text_recognizer/datasets/iam_preprocessor.py b/src/text_recognizer/datasets/iam_preprocessor.py deleted file mode 100644 index a93eb00..0000000 --- a/src/text_recognizer/datasets/iam_preprocessor.py +++ /dev/null @@ -1,196 +0,0 @@ -"""Preprocessor for extracting word letters from the IAM dataset. - -The code is mostly stolen from: - - https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py - -""" - -import collections -import itertools -from pathlib import Path -import re -from typing import List, Optional, Union - -import click -from loguru import logger -import torch - - -def load_metadata( - data_dir: Path, wordsep: str, use_words: bool = False -) -> collections.defaultdict: - """Loads IAM metadata and returns it as a dictionary.""" - forms = collections.defaultdict(list) - filename = "words.txt" if use_words else "lines.txt" - - with open(data_dir / "ascii" / filename, "r") as f: - lines = (line.strip().split() for line in f if line[0] != "#") - for line in lines: - # Skip word segmentation errors. 
- if use_words and line[1] == "err": - continue - text = " ".join(line[8:]) - - # Remove garbage tokens: - text = text.replace("#", "") - - # Swap word sep form | to wordsep - text = re.sub(r"\|+|\s", wordsep, text).strip(wordsep) - form_key = "-".join(line[0].split("-")[:2]) - line_key = "-".join(line[0].split("-")[:3]) - box_idx = 4 - use_words - box = tuple(int(val) for val in line[box_idx : box_idx + 4]) - forms[form_key].append({"key": line_key, "box": box, "text": text}) - return forms - - -class Preprocessor: - """A preprocessor for the IAM dataset.""" - - # TODO: add lower case only to when generating... - - def __init__( - self, - data_dir: Union[str, Path], - num_features: int, - tokens_path: Optional[Union[str, Path]] = None, - lexicon_path: Optional[Union[str, Path]] = None, - use_words: bool = False, - prepend_wordsep: bool = False, - ) -> None: - self.wordsep = "▁" - self._use_word = use_words - self._prepend_wordsep = prepend_wordsep - - self.data_dir = Path(data_dir) - - self.forms = load_metadata(self.data_dir, self.wordsep, use_words=use_words) - - # Load the set of graphemes: - graphemes = set() - for _, form in self.forms.items(): - for line in form: - graphemes.update(line["text"].lower()) - self.graphemes = sorted(graphemes) - - # Build the token-to-index and index-to-token maps. 
- if tokens_path is not None: - with open(tokens_path, "r") as f: - self.tokens = [line.strip() for line in f] - else: - self.tokens = self.graphemes - - if lexicon_path is not None: - with open(lexicon_path, "r") as f: - lexicon = (line.strip().split() for line in f) - lexicon = {line[0]: line[1:] for line in lexicon} - self.lexicon = lexicon - else: - self.lexicon = None - - self.graphemes_to_index = {t: i for i, t in enumerate(self.graphemes)} - self.tokens_to_index = {t: i for i, t in enumerate(self.tokens)} - self.num_features = num_features - self.text = [] - - @property - def num_tokens(self) -> int: - """Returns the number or tokens.""" - return len(self.tokens) - - @property - def use_words(self) -> bool: - """If words are used.""" - return self._use_word - - def extract_train_text(self) -> None: - """Extracts training text.""" - keys = [] - with open(self.data_dir / "task" / "trainset.txt") as f: - keys.extend((line.strip() for line in f)) - - for _, examples in self.forms.items(): - for example in examples: - if example["key"] not in keys: - continue - self.text.append(example["text"].lower()) - - def to_index(self, line: str) -> torch.LongTensor: - """Converts text to a tensor of indices.""" - token_to_index = self.graphemes_to_index - if self.lexicon is not None: - if len(line) > 0: - # If the word is not found in the lexicon, fall back to letters. 
- line = [ - t - for w in line.split(self.wordsep) - for t in self.lexicon.get(w, self.wordsep + w) - ] - token_to_index = self.tokens_to_index - if self._prepend_wordsep: - line = itertools.chain([self.wordsep], line) - return torch.LongTensor([token_to_index[t] for t in line]) - - def to_text(self, indices: List[int]) -> str: - """Converts indices to text.""" - # Roughly the inverse of `to_index` - encoding = self.graphemes - if self.lexicon is not None: - encoding = self.tokens - return self._post_process(encoding[i] for i in indices) - - def tokens_to_text(self, indices: List[int]) -> str: - """Converts tokens to text.""" - return self._post_process(self.tokens[i] for i in indices) - - def _post_process(self, indices: List[int]) -> str: - """A list join.""" - return "".join(indices).strip(self.wordsep) - - -@click.command() -@click.option("--data_dir", type=str, default=None, help="Path to iam dataset") -@click.option( - "--use_words", is_flag=True, help="Load word segmented dataset instead of lines" -) -@click.option( - "--save_text", type=str, default=None, help="Path to save parsed train text" -) -@click.option("--save_tokens", type=str, default=None, help="Path to save tokens") -def cli( - data_dir: Optional[str], - use_words: bool, - save_text: Optional[str], - save_tokens: Optional[str], -) -> None: - """CLI for extracting text data from the iam dataset.""" - if data_dir is None: - data_dir = ( - Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb" - ) - logger.debug(f"Using data dir: {data_dir}") - if not data_dir.exists(): - raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") - else: - data_dir = Path(data_dir) - - preprocessor = Preprocessor(data_dir, 64, use_words=use_words) - preprocessor.extract_train_text() - - processed_dir = data_dir.parents[2] / "processed" / "iam_lines" - logger.debug(f"Saving processed files at: {processed_dir}") - - if save_text is not None: - logger.info("Saving training text") - with 
open(processed_dir / save_text, "w") as f: - f.write("\n".join(t for t in preprocessor.text)) - - if save_tokens is not None: - logger.info("Saving tokens") - with open(processed_dir / save_tokens, "w") as f: - f.write("\n".join(preprocessor.tokens)) - - -if __name__ == "__main__": - cli() diff --git a/src/text_recognizer/datasets/sentence_generator.py b/src/text_recognizer/datasets/sentence_generator.py deleted file mode 100644 index dd76652..0000000 --- a/src/text_recognizer/datasets/sentence_generator.py +++ /dev/null @@ -1,81 +0,0 @@ -"""Downloading the Brown corpus with NLTK for sentence generating.""" - -import itertools -import re -import string -from typing import Optional - -import nltk -from nltk.corpus.reader.util import ConcatenatedCorpusView -import numpy as np - -from text_recognizer.datasets.util import DATA_DIRNAME - -NLTK_DATA_DIRNAME = DATA_DIRNAME / "raw" / "nltk" - - -class SentenceGenerator: - """Generates text sentences using the Brown corpus.""" - - def __init__(self, max_length: Optional[int] = None) -> None: - """Loads the corpus and sets word start indices.""" - self.corpus = brown_corpus() - self.word_start_indices = [0] + [ - _.start(0) + 1 for _ in re.finditer(" ", self.corpus) - ] - self.max_length = max_length - - def generate(self, max_length: Optional[int] = None) -> str: - """Generates a word or sentences from the Brown corpus. - - Sample a string from the Brown corpus of length at least one word and at most max_length, padding to - max_length with the '_' characters if sentence is shorter. - - Args: - max_length (Optional[int]): The maximum number of characters in the sentence. Defaults to None. - - Returns: - str: A sentence from the Brown corpus. - - Raises: - ValueError: If max_length was not specified at initialization and not given as an argument. - - """ - if max_length is None: - max_length = self.max_length - if max_length is None: - raise ValueError( - "Must provide max_length to this method or when making this object." 
- ) - - index = np.random.randint(0, len(self.word_start_indices) - 1) - start_index = self.word_start_indices[index] - end_index_candidates = [] - for index in range(index + 1, len(self.word_start_indices)): - if self.word_start_indices[index] - start_index > max_length: - break - end_index_candidates.append(self.word_start_indices[index]) - end_index = np.random.choice(end_index_candidates) - sampled_text = self.corpus[start_index:end_index].strip() - padding = "_" * (max_length - len(sampled_text)) - return sampled_text + padding - - -def brown_corpus() -> str: - """Returns a single string with the Brown corpus with all punctuations stripped.""" - sentences = load_nltk_brown_corpus() - corpus = " ".join(itertools.chain.from_iterable(sentences)) - corpus = corpus.translate({ord(c): None for c in string.punctuation}) - corpus = re.sub(" +", " ", corpus) - return corpus - - -def load_nltk_brown_corpus() -> ConcatenatedCorpusView: - """Load the Brown corpus using the NLTK library.""" - nltk.data.path.append(NLTK_DATA_DIRNAME) - try: - nltk.corpus.brown.sents() - except LookupError: - NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) - nltk.download("brown", download_dir=NLTK_DATA_DIRNAME) - return nltk.corpus.brown.sents() diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py deleted file mode 100644 index b6a48f5..0000000 --- a/src/text_recognizer/datasets/transforms.py +++ /dev/null @@ -1,266 +0,0 @@ -"""Transforms for PyTorch datasets.""" -from abc import abstractmethod -from pathlib import Path -import random -from typing import Any, Optional, Union - -from loguru import logger -import numpy as np -from PIL import Image -import torch -from torch import Tensor -import torch.nn.functional as F -from torchvision import transforms -from torchvision.transforms import ( - ColorJitter, - Compose, - Normalize, - RandomAffine, - RandomHorizontalFlip, - RandomRotation, - ToPILImage, - ToTensor, -) - -from 
text_recognizer.datasets.iam_preprocessor import Preprocessor -from text_recognizer.datasets.util import EmnistMapper - - -class RandomResizeCrop: - """Image transform with random resize and crop applied. - - Stolen from - - https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py - - """ - - def __init__(self, jitter: int = 10, ratio: float = 0.5) -> None: - self.jitter = jitter - self.ratio = ratio - - def __call__(self, img: np.ndarray) -> np.ndarray: - """Applies random crop and rotation to an image.""" - w, h = img.size - - # pad with white: - img = transforms.functional.pad(img, self.jitter, fill=255) - - # crop at random (x, y): - x = self.jitter + random.randint(-self.jitter, self.jitter) - y = self.jitter + random.randint(-self.jitter, self.jitter) - - # randomize aspect ratio: - size_w = w * random.uniform(1 - self.ratio, 1 + self.ratio) - size = (h, int(size_w)) - img = transforms.functional.resized_crop(img, y, x, h, w, size) - return img - - -class Transpose: - """Transposes the EMNIST image to the correct orientation.""" - - def __call__(self, image: Image) -> np.ndarray: - """Swaps axis.""" - return np.array(image).swapaxes(0, 1) - - -class Resize: - """Resizes a tensor to a specified width.""" - - def __init__(self, width: int = 952) -> None: - # The default is 952 because of the IAM dataset. 
- self.width = width - - def __call__(self, image: Tensor) -> Tensor: - """Resize tensor in the last dimension.""" - return F.interpolate(image, size=self.width, mode="nearest") - - -class AddTokens: - """Adds start of sequence and end of sequence tokens to target tensor.""" - - def __init__(self, pad_token: str, eos_token: str, init_token: str = None) -> None: - self.init_token = init_token - self.pad_token = pad_token - self.eos_token = eos_token - if self.init_token is not None: - self.emnist_mapper = EmnistMapper( - init_token=self.init_token, - pad_token=self.pad_token, - eos_token=self.eos_token, - ) - else: - self.emnist_mapper = EmnistMapper( - pad_token=self.pad_token, eos_token=self.eos_token, - ) - self.pad_value = self.emnist_mapper(self.pad_token) - self.eos_value = self.emnist_mapper(self.eos_token) - - def __call__(self, target: Tensor) -> Tensor: - """Adds a sos token to the begining and a eos token to the end of a target sequence.""" - dtype, device = target.dtype, target.device - - # Find the where padding starts. - pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item() - - target[pad_index] = self.eos_value - - if self.init_token is not None: - self.sos_value = self.emnist_mapper(self.init_token) - sos = torch.tensor([self.sos_value], dtype=dtype, device=device) - target = torch.cat([sos, target], dim=0) - - return target - - -class ApplyContrast: - """Sets everything below a threshold to zero, i.e. 
increase contrast.""" - - def __init__(self, low: float = 0.0, high: float = 0.25) -> None: - self.low = low - self.high = high - - def __call__(self, x: Tensor) -> Tensor: - """Apply mask binary mask to input tensor.""" - mask = x > np.random.RandomState().uniform(low=self.low, high=self.high) - return x * mask - - -class Unsqueeze: - """Add a dimension to the tensor.""" - - def __call__(self, x: Tensor) -> Tensor: - """Adds dim.""" - return x.unsqueeze(0) - - -class Squeeze: - """Removes the first dimension of a tensor.""" - - def __call__(self, x: Tensor) -> Tensor: - """Removes first dim.""" - return x.squeeze(0) - - -class ToLower: - """Converts target to lower case.""" - - def __call__(self, target: Tensor) -> Tensor: - """Corrects index value in target tensor.""" - device = target.device - return torch.stack([x - 26 if x > 35 else x for x in target]).to(device) - - -class ToCharcters: - """Converts integers to characters.""" - - def __init__( - self, pad_token: str, eos_token: str, init_token: str = None, lower: bool = True - ) -> None: - self.init_token = init_token - self.pad_token = pad_token - self.eos_token = eos_token - if self.init_token is not None: - self.emnist_mapper = EmnistMapper( - init_token=self.init_token, - pad_token=self.pad_token, - eos_token=self.eos_token, - lower=lower, - ) - else: - self.emnist_mapper = EmnistMapper( - pad_token=self.pad_token, eos_token=self.eos_token, lower=lower - ) - - def __call__(self, y: Tensor) -> str: - """Converts a Tensor to a str.""" - return ( - "".join([self.emnist_mapper(int(i)) for i in y]) - .strip("_") - .replace(" ", "▁") - ) - - -class WordPieces: - """Abstract transform for word pieces.""" - - def __init__( - self, - num_features: int, - data_dir: Optional[Union[str, Path]] = None, - tokens: Optional[Union[str, Path]] = None, - lexicon: Optional[Union[str, Path]] = None, - use_words: bool = False, - prepend_wordsep: bool = False, - ) -> None: - if data_dir is None: - data_dir = ( - 
Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb" - ) - logger.debug(f"Using data dir: {data_dir}") - if not data_dir.exists(): - raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") - else: - data_dir = Path(data_dir) - processed_path = ( - Path(__file__).resolve().parents[3] / "data" / "processed" / "iam_lines" - ) - tokens_path = processed_path / tokens - lexicon_path = processed_path / lexicon - - self.preprocessor = Preprocessor( - data_dir, - num_features, - tokens_path, - lexicon_path, - use_words, - prepend_wordsep, - ) - - @abstractmethod - def __call__(self, *args, **kwargs) -> Any: - """Transforms input.""" - ... - - -class ToWordPieces(WordPieces): - """Transforms str to word pieces.""" - - def __init__( - self, - num_features: int, - data_dir: Optional[Union[str, Path]] = None, - tokens: Optional[Union[str, Path]] = None, - lexicon: Optional[Union[str, Path]] = None, - use_words: bool = False, - prepend_wordsep: bool = False, - ) -> None: - super().__init__( - num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep - ) - - def __call__(self, line: str) -> Tensor: - """Transforms str to word pieces.""" - return self.preprocessor.to_index(line) - - -class ToText(WordPieces): - """Takes word pieces and converts them to text.""" - - def __init__( - self, - num_features: int, - data_dir: Optional[Union[str, Path]] = None, - tokens: Optional[Union[str, Path]] = None, - lexicon: Optional[Union[str, Path]] = None, - use_words: bool = False, - prepend_wordsep: bool = False, - ) -> None: - super().__init__( - num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep - ) - - def __call__(self, x: Tensor) -> str: - """Converts tensor to text.""" - return self.preprocessor.to_text(x.tolist()) diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py deleted file mode 100644 index da87756..0000000 --- a/src/text_recognizer/datasets/util.py +++ /dev/null @@ -1,209 +0,0 @@ 
-"""Util functions for datasets.""" -import hashlib -import json -import os -from pathlib import Path -import string -from typing import Dict, List, Optional, Union -from urllib.request import urlretrieve - -from loguru import logger -import numpy as np -import torch -from torch import Tensor -from torchvision.datasets import EMNIST -from tqdm import tqdm - -DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data" -ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json" - - -def save_emnist_essentials(emnsit_dataset: EMNIST = EMNIST) -> None: - """Extract and saves EMNIST essentials.""" - labels = emnsit_dataset.classes - labels.sort() - mapping = [(i, str(label)) for i, label in enumerate(labels)] - essentials = { - "mapping": mapping, - "input_shape": tuple(np.array(emnsit_dataset[0][0]).shape[:]), - } - logger.info("Saving emnist essentials...") - with open(ESSENTIALS_FILENAME, "w") as f: - json.dump(essentials, f) - - -def download_emnist() -> None: - """Download the EMNIST dataset via the PyTorch class.""" - logger.info(f"Data directory is: {DATA_DIRNAME}") - dataset = EMNIST(root=DATA_DIRNAME, split="byclass", download=True) - save_emnist_essentials(dataset) - - -class EmnistMapper: - """Mapper between network output to Emnist character.""" - - def __init__( - self, - pad_token: str, - init_token: Optional[str] = None, - eos_token: Optional[str] = None, - lower: bool = False, - ) -> None: - """Loads the emnist essentials file with the mapping and input shape.""" - self.init_token = init_token - self.pad_token = pad_token - self.eos_token = eos_token - self.lower = lower - - self.essentials = self._load_emnist_essentials() - # Load dataset information. 
- self._mapping = dict(self.essentials["mapping"]) - self._augment_emnist_mapping() - self._inverse_mapping = {v: k for k, v in self.mapping.items()} - self._num_classes = len(self.mapping) - self._input_shape = self.essentials["input_shape"] - - def __call__(self, token: Union[str, int, np.uint8, Tensor]) -> Union[str, int]: - """Maps the token to emnist character or character index. - - If the token is an integer (index), the method will return the Emnist character corresponding to that index. - If the token is a str (Emnist character), the method will return the corresponding index for that character. - - Args: - token (Union[str, int, np.uint8, Tensor]): Either a string or index (integer). - - Returns: - Union[str, int]: The mapping result. - - Raises: - KeyError: If the index or string does not exist in the mapping. - - """ - if ( - (isinstance(token, np.uint8) or isinstance(token, int)) - or torch.is_tensor(token) - and int(token) in self.mapping - ): - return self.mapping[int(token)] - elif isinstance(token, str) and token in self._inverse_mapping: - return self._inverse_mapping[token] - else: - raise KeyError(f"Token {token} does not exist in the mappings.") - - @property - def mapping(self) -> Dict: - """Returns the mapping between index and character.""" - return self._mapping - - @property - def inverse_mapping(self) -> Dict: - """Returns the mapping between character and index.""" - return self._inverse_mapping - - @property - def num_classes(self) -> int: - """Returns the number of classes in the dataset.""" - return self._num_classes - - @property - def input_shape(self) -> List[int]: - """Returns the input shape of the Emnist characters.""" - return self._input_shape - - def _load_emnist_essentials(self) -> Dict: - """Load the EMNIST mapping.""" - with open(str(ESSENTIALS_FILENAME)) as f: - essentials = json.load(f) - return essentials - - def _augment_emnist_mapping(self) -> None: - """Augment the mapping with extra symbols.""" - # Extra symbols in 
IAM dataset - if self.lower: - self._mapping = { - k: str(v) - for k, v in enumerate(list(range(10)) + list(string.ascii_lowercase)) - } - - extra_symbols = [ - " ", - "!", - '"', - "#", - "&", - "'", - "(", - ")", - "*", - "+", - ",", - "-", - ".", - "/", - ":", - ";", - "?", - ] - - # padding symbol, and acts as blank symbol as well. - extra_symbols.append(self.pad_token) - - if self.init_token is not None: - extra_symbols.append(self.init_token) - - if self.eos_token is not None: - extra_symbols.append(self.eos_token) - - max_key = max(self.mapping.keys()) - extra_mapping = {} - for i, symbol in enumerate(extra_symbols): - extra_mapping[max_key + 1 + i] = symbol - - self._mapping = {**self.mapping, **extra_mapping} - - -def compute_sha256(filename: Union[Path, str]) -> str: - """Returns the SHA256 checksum of a file.""" - with open(filename, "rb") as f: - return hashlib.sha256(f.read()).hexdigest() - - -class TqdmUpTo(tqdm): - """TQDM progress bar when downloading files. - - From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py - - """ - - def update_to( - self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None - ) -> None: - """Updates the progress bar. - - Args: - blocks (int): Number of blocks transferred so far. Defaults to 1. - block_size (int): Size of each block, in tqdm units. Defaults to 1. - total_size (Optional[int]): Total size in tqdm units. Defaults to None. 
def _download_raw_dataset(metadata: Dict) -> None:
    """Downloads the raw dataset described by ``metadata`` unless it already exists.

    Args:
        metadata (Dict): Must contain the keys ``filename`` (download target),
            ``url`` (source) and ``sha256`` (expected checksum of the file).

    Raises:
        ValueError: If the checksum of the downloaded file does not match
            ``metadata["sha256"]``.

    """
    filename = metadata["filename"]
    if os.path.exists(filename):
        return
    logger.info(f"Downloading raw dataset from {metadata['url']}...")
    try:
        download_url(metadata["url"], filename)
        logger.info("Computing SHA-256...")
        sha256 = compute_sha256(filename)
        if sha256 != metadata["sha256"]:
            raise ValueError(
                "Downloaded data file SHA-256 does not match that listed in metadata document."
            )
    except BaseException:
        # Remove a partial or corrupt file (including after Ctrl-C). Otherwise
        # the exists-check above would accept it silently on the next call.
        if os.path.exists(filename):
            os.remove(filename)
        raise
class Model(ABC):
    """Abstract Model class with composition of different parts defining a PyTorch neural network."""

    def __init__(
        self,
        network_fn: str,
        dataset: str,
        network_args: Optional[Dict] = None,
        dataset_args: Optional[Dict] = None,
        metrics: Optional[Dict] = None,
        criterion: Optional[Callable] = None,
        criterion_args: Optional[Dict] = None,
        optimizer: Optional[Callable] = None,
        optimizer_args: Optional[Dict] = None,
        lr_scheduler: Optional[Callable] = None,
        lr_scheduler_args: Optional[Dict] = None,
        swa_args: Optional[Dict] = None,
        device: Optional[str] = None,
    ) -> None:
        """Base class, to be inherited by model for specific type of data.

        Args:
            network_fn (str): The name of network.
            dataset (str): The name dataset class.
            network_args (Optional[Dict]): Arguments for the network. Defaults to None.
            dataset_args (Optional[Dict]): Arguments for the dataset.
            metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
            criterion (Optional[Callable]): The criterion to evaluate the performance of the network.
                Defaults to None.
            criterion_args (Optional[Dict]): Dict of arguments for criterion. Defaults to None.
            optimizer (Optional[Callable]): The optimizer for updating the weights. Defaults to None.
            optimizer_args (Optional[Dict]): Dict of arguments for optimizer. Defaults to None.
            lr_scheduler (Optional[Callable]): A PyTorch learning rate scheduler. Defaults to None.
            lr_scheduler_args (Optional[Dict]): Dict of arguments for learning rate scheduler. Defaults to
                None.
            swa_args (Optional[Dict]): Dict of arguments for stochastic weight averaging. Defaults to
                None.
            device (Optional[str]): Name of the device to train on. Defaults to None.

        """
        self._name = f"{self.__class__.__name__}_{dataset}_{network_fn}"
        # Has to be set in subclass.
        self._mapper = None

        # Placeholder.
        self._input_shape = None

        self.dataset_name = dataset
        self.dataset = None
        self.dataset_args = dataset_args

        # Placeholders for datasets.
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

        # Stochastic Weight Averaging placeholders. These must exist before
        # _configure_network(), because load_weights() may fill _swa_network.
        self.swa_args = swa_args
        self._swa_scheduler = None
        self._swa_network = None
        self._use_swa_model = False

        # Experiment directory.
        self.model_dir = None

        # Flag for configured model.
        self.is_configured = False
        self.data_prepared = False

        # Flag for stopping training.
        self.stop_training = False

        self._metrics = metrics

        # Set the device.
        self._device = (
            torch.device("cuda" if torch.cuda.is_available() else "cpu")
            if device is None
            else device
        )

        # Configure network.
        self._network = None
        self._network_args = network_args
        self._configure_network(network_fn)

        # Place network on device (GPU).
        self.to_device()

        # Loss and Optimizer placeholders for before loading. These hold the
        # factories until configure_model() replaces them with instances.
        self._criterion = criterion
        self.criterion_args = criterion_args

        self._optimizer = optimizer
        self.optimizer_args = optimizer_args

        self._lr_scheduler = lr_scheduler
        self.lr_scheduler_args = lr_scheduler_args

    def configure_model(self) -> None:
        """Configures criterion and optimizers (only once)."""
        if not self.is_configured:
            self._configure_criterion()
            self._configure_optimizers()

            # Set this flag to true to prevent the model from configuring again.
            self.is_configured = True

    def prepare_data(self) -> None:
        """Prepare data for training.

        Loads the train and test datasets named by ``dataset_name`` and splits
        the train dataset into training and validation partitions using
        ``dataset_args["train_args"]["train_fraction"]``.
        """
        # TODO add downloading.
        if not self.data_prepared:
            # Load dataset module.
            self.dataset = getattr(datasets, self.dataset_name)

            # Load train dataset.
            train_dataset = self.dataset(train=True, **self.dataset_args["args"])
            train_dataset.load_or_generate_data()

            # Set input shape.
            self._input_shape = train_dataset.input_shape

            # Split train dataset into a training and validation partition.
            dataset_len = len(train_dataset)
            train_len = int(
                self.dataset_args["train_args"]["train_fraction"] * dataset_len
            )
            val_len = dataset_len - train_len
            self.train_dataset, self.val_dataset = random_split(
                train_dataset, lengths=[train_len, val_len]
            )

            # Load test dataset.
            self.test_dataset = self.dataset(train=False, **self.dataset_args["args"])
            self.test_dataset.load_or_generate_data()

            # Set the flag to true to disable ability to load data again.
            self.data_prepared = True

    def train_dataloader(self) -> DataLoader:
        """Returns data loader for training set."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.dataset_args["train_args"]["batch_size"],
            num_workers=self.dataset_args["train_args"]["num_workers"],
            shuffle=True,
            pin_memory=True,
        )

    def val_dataloader(self) -> DataLoader:
        """Returns data loader for validation set."""
        # NOTE(review): the validation loader shuffles while the test loader
        # does not. Kept as-is; confirm whether this is intentional.
        return DataLoader(
            self.val_dataset,
            batch_size=self.dataset_args["train_args"]["batch_size"],
            num_workers=self.dataset_args["train_args"]["num_workers"],
            shuffle=True,
            pin_memory=True,
        )

    def test_dataloader(self) -> DataLoader:
        """Returns data loader for test set."""
        return DataLoader(
            self.test_dataset,
            batch_size=self.dataset_args["train_args"]["batch_size"],
            num_workers=self.dataset_args["train_args"]["num_workers"],
            shuffle=False,
            pin_memory=True,
        )

    def _configure_network(self, network_fn: str) -> None:
        """Loads the network.

        Args:
            network_fn (str): Name of the network class in ``networks``.
        """
        # If no network arguments are given, load pretrained weights if they exist.
        # Load network module.
        network_fn = getattr(networks, network_fn)
        if self._network_args is None:
            self.load_weights(network_fn)
        else:
            self._network = network_fn(**self._network_args)

    def _configure_criterion(self) -> None:
        """Instantiates the criterion from its factory, if one was given."""
        if self._criterion is not None:
            # Missing criterion args mean "use the criterion's defaults".
            self._criterion = self._criterion(**(self.criterion_args or {}))

    def _configure_optimizers(self) -> None:
        """Instantiates the optimizer, learning rate scheduler and SWA helpers."""
        if self._optimizer is not None:
            self._optimizer = self._optimizer(
                self._network.parameters(), **(self.optimizer_args or {})
            )

        if self._optimizer is not None and self._lr_scheduler is not None:
            # Work on a copy so the caller's dict is not mutated.
            scheduler_args = dict(self.lr_scheduler_args or {})
            if "steps_per_epoch" in scheduler_args:
                scheduler_args["steps_per_epoch"] = len(self.train_dataloader())

            # Assume lr scheduler should update at each epoch if not specified.
            interval = scheduler_args.pop("interval", "epoch")

            # Keep the attribute in the same final state as before: steps
            # resolved and "interval" removed.
            self.lr_scheduler_args = scheduler_args
            self._lr_scheduler = {
                "lr_scheduler": self._lr_scheduler(self._optimizer, **scheduler_args),
                "interval": interval,
            }

        if self.swa_args is not None:
            # The SWA learning rate scheduler needs an optimizer to wrap.
            if self._optimizer is not None:
                self._swa_scheduler = {
                    "swa_scheduler": SWALR(self._optimizer, swa_lr=self.swa_args["lr"]),
                    "swa_start": self.swa_args["start"],
                }
            self._swa_network = AveragedModel(self._network).to(self.device)

    @property
    def name(self) -> str:
        """Returns the name of the model."""
        return self._name

    @property
    def input_shape(self) -> Tuple[int, ...]:
        """The input shape."""
        return self._input_shape

    @property
    def mapper(self) -> EmnistMapper:
        """Returns the mapper that maps between ints and chars."""
        return self._mapper

    @property
    def mapping(self) -> Optional[Dict]:
        """Returns the mapping between network output and Emnist character."""
        return self._mapper.mapping if self._mapper is not None else None

    def eval(self) -> None:
        """Sets the network to evaluation mode."""
        self._network.eval()

    def train(self) -> None:
        """Sets the network to train mode."""
        self._network.train()

    @property
    def device(self) -> Union[str, torch.device]:
        """Device where the weights are stored, i.e. cpu or cuda."""
        return self._device

    @property
    def metrics(self) -> Optional[Dict]:
        """Metrics."""
        return self._metrics

    @property
    def criterion(self) -> Optional[Callable]:
        """Criterion."""
        return self._criterion

    @property
    def optimizer(self) -> Optional[Callable]:
        """Optimizer."""
        return self._optimizer

    @property
    def lr_scheduler(self) -> Optional[Dict]:
        """Returns a dictionary with the learning rate scheduler."""
        return self._lr_scheduler

    @property
    def swa_scheduler(self) -> Optional[Dict]:
        """Returns a dictionary with the stochastic weight averaging scheduler."""
        return self._swa_scheduler

    @property
    def swa_network(self) -> Optional[Callable]:
        """Returns the stochastic weight averaging network."""
        return self._swa_network

    @property
    def network(self) -> nn.Module:
        """Neural network."""
        # Always the plain network; forward() picks the SWA network.
        return self._network

    @property
    def weights_filename(self) -> str:
        """Filepath to the network weights (creates the weights directory if missing)."""
        WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True)
        return str(WEIGHT_DIRNAME / f"{self._name}_weights.pt")

    def use_swa_model(self) -> None:
        """Set to use predictions from SWA model."""
        if self.swa_network is not None:
            self._use_swa_model = True

    def forward(self, x: Tensor) -> Tensor:
        """Feedforward pass with the network (SWA network if enabled)."""
        if self._use_swa_model:
            return self.swa_network(x)
        return self.network(x)

    def summary(
        self,
        input_shape: Optional[Union[List, Tuple]] = None,
        depth: int = 3,
        device: Optional[str] = None,
    ) -> None:
        """Prints a summary of the network architecture."""
        device = self.device if device is None else device

        if input_shape is None and self._input_shape is not None:
            input_shape = tuple(self._input_shape)

        if input_shape is not None:
            summary(self.network, input_shape, depth=depth, device=device)
        else:
            logger.warning("Could not print summary as input shape is not set.")

    def to_device(self) -> None:
        """Places the network on the device (GPU)."""
        self._network.to(self._device)

    def _get_state_dict(self) -> Dict:
        """Get the state dict of the model."""
        state = {"model_state": self._network.state_dict()}

        if self._optimizer is not None:
            state["optimizer_state"] = self._optimizer.state_dict()

        if self._lr_scheduler is not None:
            state["scheduler_state"] = self._lr_scheduler["lr_scheduler"].state_dict()
            state["scheduler_interval"] = self._lr_scheduler["interval"]

        if self._swa_network is not None:
            state["swa_network"] = self._swa_network.state_dict()

        return state

    def load_from_checkpoint(self, checkpoint_path: Union[str, Path]) -> None:
        """Load a previously saved checkpoint.

        Args:
            checkpoint_path (Path): Path to the experiment with the checkpoint.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.

        """
        checkpoint_path = Path(checkpoint_path)
        self.prepare_data()
        self.configure_model()
        logger.debug("Loading checkpoint...")
        if not checkpoint_path.exists():
            # Previously a non-f-string was logged and loading carried on.
            raise FileNotFoundError(f"Checkpoint file does not exist: {checkpoint_path}")

        checkpoint = torch.load(str(checkpoint_path), map_location=self.device)
        self._network.load_state_dict(checkpoint["model_state"])

        if self._optimizer is not None:
            self._optimizer.load_state_dict(checkpoint["optimizer_state"])

        if self._lr_scheduler is not None:
            # Does not work when loading from previous checkpoint and trying to train beyond the last max epochs
            # with OneCycleLR.
            if self._lr_scheduler["lr_scheduler"].__class__.__name__ != "OneCycleLR":
                self._lr_scheduler["lr_scheduler"].load_state_dict(
                    checkpoint["scheduler_state"]
                )
                self._lr_scheduler["interval"] = checkpoint["scheduler_interval"]

        if self._swa_network is not None:
            self._swa_network.load_state_dict(checkpoint["swa_network"])

    def save_checkpoint(
        self, checkpoint_path: Path, is_best: bool, epoch: int, val_metric: str
    ) -> None:
        """Saves a checkpoint of the model.

        Args:
            checkpoint_path (Path): Path to the experiment with the checkpoint.
            is_best (bool): If it is the currently best model.
            epoch (int): The epoch of the checkpoint.
            val_metric (str): Validation metric.

        """
        # Accept plain strings as well as Path objects.
        checkpoint_path = Path(checkpoint_path)

        state = self._get_state_dict()
        state["is_best"] = is_best
        state["epoch"] = epoch
        state["network_args"] = self._network_args

        checkpoint_path.mkdir(parents=True, exist_ok=True)

        logger.debug("Saving checkpoint...")
        filepath = str(checkpoint_path / "last.pt")
        torch.save(state, filepath)

        if is_best:
            logger.debug(
                f"Found a new best {val_metric}. Saving best checkpoint and weights."
            )
            shutil.copyfile(filepath, str(checkpoint_path / "best.pt"))

    def load_weights(self, network_fn: Optional[Type[nn.Module]] = None) -> None:
        """Load the network weights.

        Args:
            network_fn (Optional[Type[nn.Module]]): Network class to instantiate
                with the stored ``network_args``. If None, the weights are loaded
                into the existing network.

        Raises:
            FileNotFoundError: If no pretrained weights file exists.

        """
        logger.debug("Loading network with pretrained weights.")
        # glob() returns an empty list when nothing matches, so check before
        # indexing. The old code raised IndexError and never reached its check.
        matches = glob(self.weights_filename)
        if not matches:
            raise FileNotFoundError(
                f"Could not find any pretrained weights at {self.weights_filename}"
            )
        filename = matches[0]

        # Loading state dictionary.
        state_dict = torch.load(filename, map_location=torch.device(self._device))
        self._network_args = state_dict["network_args"]
        weights = state_dict["model_state"]

        # Initializes the network with trained weights.
        if network_fn is not None:
            self._network = network_fn(**self._network_args)
        self._network.load_state_dict(weights)

        if "swa_network" in state_dict:
            self._swa_network = AveragedModel(self._network).to(self.device)
            self._swa_network.load_state_dict(state_dict["swa_network"])

    def save_weights(self, path: Path) -> None:
        """Save the network weights by copying ``path / "best.pt"`` to the weights file."""
        logger.debug("Saving the best network weights.")
        shutil.copyfile(str(Path(path) / "best.pt"), self.weights_filename)
None, - criterion: Optional[Callable] = None, - criterion_args: Optional[Dict] = None, - optimizer: Optional[Callable] = None, - optimizer_args: Optional[Dict] = None, - lr_scheduler: Optional[Callable] = None, - lr_scheduler_args: Optional[Dict] = None, - swa_args: Optional[Dict] = None, - device: Optional[str] = None, - ) -> None: - """Initializes the CharacterModel.""" - - super().__init__( - network_fn, - dataset, - network_args, - dataset_args, - metrics, - criterion, - criterion_args, - optimizer, - optimizer_args, - lr_scheduler, - lr_scheduler_args, - swa_args, - device, - ) - self.pad_token = dataset_args["args"]["pad_token"] - if self._mapper is None: - self._mapper = EmnistMapper(pad_token=self.pad_token,) - self.tensor_transform = ToTensor() - self.softmax = nn.Softmax(dim=0) - - @torch.no_grad() - def predict_on_image( - self, image: Union[np.ndarray, torch.Tensor] - ) -> Tuple[str, float]: - """Character prediction on an image. - - Args: - image (Union[np.ndarray, torch.Tensor]): An image containing a character. - - Returns: - Tuple[str, float]: The predicted character and the confidence in the prediction. - - """ - self.eval() - - if image.dtype == np.uint8: - # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. - image = self.tensor_transform(image) - if image.dtype == torch.uint8: - # If the image is an unscaled tensor. - image = image.type("torch.FloatTensor") / 255 - - # Put the image tensor on the device the model weights are on. 
class CRNNModel(Model):
    """Model for predicting a sequence of characters from an image of a text line."""

    # Index used both to strip padding from targets and as the CTC blank label.
    # NOTE(review): presumably the index of the pad token in the default
    # (non-lowercase) EmnistMapper - confirm against the mapper.
    BLANK_LABEL = 79

    def __init__(
        self,
        network_fn: str,
        dataset: str,
        network_args: Optional[Dict] = None,
        dataset_args: Optional[Dict] = None,
        metrics: Optional[Dict] = None,
        criterion: Optional[Callable] = None,
        criterion_args: Optional[Dict] = None,
        optimizer: Optional[Callable] = None,
        optimizer_args: Optional[Dict] = None,
        lr_scheduler: Optional[Callable] = None,
        lr_scheduler_args: Optional[Dict] = None,
        swa_args: Optional[Dict] = None,
        device: Optional[str] = None,
    ) -> None:
        """Initializes the CRNNModel.

        All arguments are forwarded unchanged to ``Model.__init__``.
        ``dataset_args["args"]["pad_token"]`` is read here, so ``dataset_args``
        is required in practice even though it defaults to None.
        """
        super().__init__(
            network_fn,
            dataset,
            network_args,
            dataset_args,
            metrics,
            criterion,
            criterion_args,
            optimizer,
            optimizer_args,
            lr_scheduler,
            lr_scheduler_args,
            swa_args,
            device,
        )

        self.pad_token = dataset_args["args"]["pad_token"]
        if self._mapper is None:
            self._mapper = EmnistMapper(pad_token=self.pad_token,)
        self.tensor_transform = ToTensor()

    def criterion(self, output: Tensor, targets: Tensor) -> Tensor:
        """Computes the CTC loss.

        Args:
            output (Tensor): Model predictions.
            targets (Tensor): Correct output sequence.

        Returns:
            Tensor: The CTC loss.

        """
        # Every sequence in the batch uses all output.shape[0] time steps;
        # output.shape[1] is taken as the batch size.
        input_lengths = torch.full(
            size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long,
        )

        # Remove padding symbol as it acts as the blank symbol.
        unpadded = [t[t < self.BLANK_LABEL] for t in targets]
        target_lengths = torch.tensor(
            [len(t) for t in unpadded], dtype=torch.long, device=self.device
        )

        # Concatenate once instead of growing a float tensor inside the loop.
        if unpadded:
            flat_targets = torch.cat(unpadded).to(self.device, dtype=torch.long)
        else:
            flat_targets = torch.empty(0, dtype=torch.long, device=self.device)

        return self._criterion(output, flat_targets, input_lengths, target_lengths)

    @torch.no_grad()
    def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
        """Predict on a single input.

        Args:
            image (Union[np.ndarray, Tensor]): An image of a line of text.

        Returns:
            Tuple[str, float]: The predicted characters and the confidence in
                the prediction.

        """
        self.eval()

        if image.dtype == np.uint8:
            # Converts an image with range [0, 255] to a PyTorch Tensor with range [0, 1].
            image = self.tensor_transform(image)

        # Rescale image between 0 and 1.
        if image.dtype == torch.uint8:
            # If the image is an unscaled tensor.
            image = image.type("torch.FloatTensor") / 255

        # Put the image tensor on the device the model weights are on.
        image = image.to(self.device)
        log_probs = self.forward(image)

        raw_pred, _ = greedy_decoder(
            predictions=log_probs,
            character_mapper=self.mapper,
            blank_label=self.BLANK_LABEL,
            collapse_repeated=True,
        )

        best_log_probs, _ = log_probs.max(dim=2)

        predicted_characters = "".join(raw_pred[0])
        # Assumes the network emits log-probabilities (CTC loss expects them)
        # - TODO confirm. The joint probability of the greedy path is then
        # exp(sum of logs). The old code multiplied the logs together.
        confidence_of_prediction = best_log_probs.sum(dim=0).exp().item()

        return predicted_characters, confidence_of_prediction
dataset_args["args"]["pad_token"] - self.lower = dataset_args["args"]["lower"] - - if self._mapper is None: - self._mapper = EmnistMapper(pad_token=self.pad_token, lower=self.lower,) - - self.tensor_transform = ToTensor() - - def criterion(self, output: Tensor, targets: Tensor) -> Tensor: - """Computes the CTC loss. - - Args: - output (Tensor): Model predictions. - targets (Tensor): Correct output sequence. - - Returns: - Tensor: The CTC loss. - - """ - # Input lengths on the form [T, B] - input_lengths = torch.full( - size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long, - ) - - # Configure target tensors for ctc loss. - targets_ = Tensor([]).to(self.device) - target_lengths = [] - for t in targets: - # Remove padding symbol as it acts as the blank symbol. - t = t[t < 53] - targets_ = torch.cat([targets_, t]) - target_lengths.append(len(t)) - - targets = targets_.type(dtype=torch.long) - target_lengths = ( - torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device) - ) - - return self._criterion(output, targets, input_lengths, target_lengths) - - @torch.no_grad() - def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: - """Predict on a single input.""" - self.eval() - - if image.dtype == np.uint8: - # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. - image = self.tensor_transform(image) - - # Rescale image between 0 and 1. - if image.dtype == torch.uint8: - # If the image is an unscaled tensor. - image = image.type("torch.FloatTensor") / 255 - - # Put the image tensor on the device the model weights are on. 
class SegmentationModel(Model):
    """Model for segmenting lines in an image."""

    def __init__(
        self,
        network_fn: str,
        dataset: str,
        network_args: Optional[Dict] = None,
        dataset_args: Optional[Dict] = None,
        metrics: Optional[Dict] = None,
        criterion: Optional[Callable] = None,
        criterion_args: Optional[Dict] = None,
        optimizer: Optional[Callable] = None,
        optimizer_args: Optional[Dict] = None,
        lr_scheduler: Optional[Callable] = None,
        lr_scheduler_args: Optional[Dict] = None,
        swa_args: Optional[Dict] = None,
        device: Optional[str] = None,
    ) -> None:
        """Initializes the SegmentationModel.

        All arguments are forwarded unchanged to ``Model.__init__``.
        """
        super().__init__(
            network_fn,
            dataset,
            network_args,
            dataset_args,
            metrics,
            criterion,
            criterion_args,
            optimizer,
            optimizer_args,
            lr_scheduler,
            lr_scheduler_args,
            swa_args,
            device,
        )
        self.tensor_transform = ToTensor()
        self.softmax = nn.Softmax(dim=2)

    @torch.no_grad()
    def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tensor:
        """Predict on a single input.

        Args:
            image (Union[np.ndarray, Tensor]): The image to segment.

        Returns:
            Tensor: The argmax over dimension 1 of the network output.

        """
        self.eval()

        if isinstance(image, np.ndarray):
            # A NumPy dtype must be compared with ==. The former
            # `image.dtype is np.uint8` was always False, so uint8 images
            # skipped the [0, 255] -> [0, 1] scaling of ToTensor.
            if image.dtype == np.uint8:
                image = self.tensor_transform(image)
            else:
                image = Tensor(image)
        elif image.dtype == torch.uint8 or image.dtype == torch.int64:
            # If the image is an unscaled tensor, rescale it between 0 and 1.
            image = image.type("torch.FloatTensor") / 255

        # Put the image tensor on the device the model weights are on.
        image = image.to(self.device)

        logits = self.forward(image)

        segmentation_mask = torch.argmax(logits, dim=1)

        return segmentation_mask
network_fn, - dataset, - network_args, - dataset_args, - metrics, - criterion, - criterion_args, - optimizer, - optimizer_args, - lr_scheduler, - lr_scheduler_args, - swa_args, - device, - ) - self.init_token = dataset_args["args"]["init_token"] - self.pad_token = dataset_args["args"]["pad_token"] - self.eos_token = dataset_args["args"]["eos_token"] - self.lower = dataset_args["args"]["lower"] - self.max_len = 100 - - if self._mapper is None: - self._mapper = EmnistMapper( - init_token=self.init_token, - pad_token=self.pad_token, - eos_token=self.eos_token, - lower=self.lower, - ) - self.tensor_transform = transforms.Compose( - [transforms.ToTensor(), transforms.Normalize(mean=[0.912], std=[0.168])] - ) - self.softmax = nn.Softmax(dim=2) - - @torch.no_grad() - def _generate_sentence(self, image: Tensor) -> Tuple[List, float]: - src = self.network.extract_image_features(image) - - # Added for vqvae transformer. - if isinstance(src, Tuple): - src = src[0] - - memory = self.network.encoder(src) - - confidence_of_predictions = [] - trg_indices = [self.mapper(self.init_token)] - - for _ in range(self.max_len - 1): - trg = torch.tensor(trg_indices, device=self.device)[None, :].long() - trg = self.network.target_embedding(trg) - logits = self.network.decoder(trg=trg, memory=memory, trg_mask=None) - - # Convert logits to probabilities. 
- probs = self.softmax(logits) - - pred_token = probs.argmax(2)[:, -1].item() - confidence = probs.max(2).values[:, -1].item() - - trg_indices.append(pred_token) - confidence_of_predictions.append(confidence) - - if pred_token == self.mapper(self.eos_token): - break - - confidence = np.min(confidence_of_predictions) - predicted_characters = "".join([self.mapper(x) for x in trg_indices[1:]]) - - return predicted_characters, confidence - - @torch.no_grad() - def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]: - """Predict on a single input.""" - self.eval() - - if image.dtype == np.uint8: - # Converts an image with range [0, 255] with to PyTorch Tensor with range [0, 1]. - image = self.tensor_transform(image) - - # Rescale image between 0 and 1. - if image.dtype == torch.uint8: - # If the image is an unscaled tensor. - image = image.type("torch.FloatTensor") / 255 - - # Put the image tensor on the device the model weights are on. - image = image.to(self.device) - - (predicted_characters, confidence_of_prediction,) = self._generate_sentence( - image - ) - - return predicted_characters, confidence_of_prediction diff --git a/src/text_recognizer/models/vqvae_model.py b/src/text_recognizer/models/vqvae_model.py deleted file mode 100644 index 70f6f1f..0000000 --- a/src/text_recognizer/models/vqvae_model.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Defines the VQVAEModel class.""" -from typing import Callable, Dict, Optional, Tuple, Type, Union - -import numpy as np -import torch -from torch import nn -from torch.utils.data import Dataset -from torchvision.transforms import ToTensor - -from text_recognizer.datasets import EmnistMapper -from text_recognizer.models.base import Model - - -class VQVAEModel(Model): - """Model for reconstructing images from codebook.""" - - def __init__( - self, - network_fn: Type[nn.Module], - dataset: Type[Dataset], - network_args: Optional[Dict] = None, - dataset_args: Optional[Dict] = None, - metrics: Optional[Dict] = 
None, - criterion: Optional[Callable] = None, - criterion_args: Optional[Dict] = None, - optimizer: Optional[Callable] = None, - optimizer_args: Optional[Dict] = None, - lr_scheduler: Optional[Callable] = None, - lr_scheduler_args: Optional[Dict] = None, - swa_args: Optional[Dict] = None, - device: Optional[str] = None, - ) -> None: - """Initializes the CharacterModel.""" - - super().__init__( - network_fn, - dataset, - network_args, - dataset_args, - metrics, - criterion, - criterion_args, - optimizer, - optimizer_args, - lr_scheduler, - lr_scheduler_args, - swa_args, - device, - ) - self.pad_token = dataset_args["args"]["pad_token"] - if self._mapper is None: - self._mapper = EmnistMapper(pad_token=self.pad_token,) - self.tensor_transform = ToTensor() - self.softmax = nn.Softmax(dim=0) - - @torch.no_grad() - def predict_on_image(self, image: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: - """Reconstruction of image. - - Args: - image (Union[np.ndarray, torch.Tensor]): An image containing a character. - - Returns: - Tuple[str, float]: The predicted character and the confidence in the prediction. - - """ - self.eval() - - if image.dtype == np.uint8: - # Converts an image with range [0, 255] with to Pytorch Tensor with range [0, 1]. - image = self.tensor_transform(image) - if image.dtype == torch.uint8: - # If the image is an unscaled tensor. - image = image.type("torch.FloatTensor") / 255 - - # Put the image tensor on the device the model weights are on. 
"""Network modules."""
from .cnn import CNN
from .cnn_transformer import CNNTransformer
from .crnn import ConvolutionalRecurrentNetwork
from .ctc import greedy_decoder
from .densenet import DenseNet
from .lenet import LeNet
from .metrics import accuracy, cer, wer
from .mlp import MLP
from .residual_network import ResidualNetwork, ResidualNetworkEncoder
from .transducer import load_transducer_loss, TDS2d
from .transformer import Transformer
from .unet import UNet
from .util import sliding_window
from .vit import ViT
from .vq_transformer import VQTransformer
from .vqvae import VQVAE
from .wide_resnet import WideResidualNetwork

# Public API of the package. Every entry must be bound by one of the imports
# above: a star-import raises AttributeError for any listed name the module
# does not define. "FCN" used to be listed here without a matching import and
# has therefore been removed.
__all__ = [
    "accuracy",
    "cer",
    "CNN",
    "CNNTransformer",
    "ConvolutionalRecurrentNetwork",
    "DenseNet",
    "greedy_decoder",
    "MLP",
    "LeNet",
    "load_transducer_loss",
    "ResidualNetwork",
    "ResidualNetworkEncoder",
    "sliding_window",
    "UNet",
    "TDS2d",
    "Transformer",
    "ViT",
    "VQTransformer",
    "VQVAE",
    "wer",
    "WideResidualNetwork",
]
- -Stolen from: https://github.com/budzianowski/PyTorch-Beam-Search-Decoding/blob/master/decode_beam.py - -""" -# from typing import List -# from Queue import PriorityQueue - -# from loguru import logger -# import torch -# from torch import nn -# from torch import Tensor -# import torch.nn.functional as F - - -# class Node: -# def __init__( -# self, parent: Node, target_index: int, log_prob: Tensor, length: int -# ) -> None: -# self.parent = parent -# self.target_index = target_index -# self.log_prob = log_prob -# self.length = length -# self.reward = 0.0 - -# def eval(self, alpha: float = 1.0) -> Tensor: -# return self.log_prob / (self.length - 1 + 1e-6) + alpha * self.reward - - -# @torch.no_grad() -# def beam_decoder( -# network, mapper, device, memory: Tensor = None, max_len: int = 97, -# ) -> Tensor: -# beam_width = 10 -# topk = 1 # How many sentences to generate. - -# trg_indices = [mapper(mapper.init_token)] - -# end_nodes = [] - -# node = Node(None, trg_indices, 0, 1) -# nodes = PriorityQueue() - -# nodes.put((node.eval(), node)) -# q_size = 1 - -# # Beam search -# for _ in range(max_len): -# if q_size > 2000: -# logger.warning("Could not decoder input") -# break - -# # Fetch the best node. -# score, n = nodes.get() -# decoder_input = n.target_index - -# if n.target_index == mapper(mapper.eos_token) and n.parent is not None: -# end_nodes.append((score, n)) - -# # If we reached the maximum number of sentences required. -# if len(end_nodes) >= 1: -# break -# else: -# continue - -# # Forward pass with transformer. 
class CNN(nn.Module):
    """Simple convolutional backbone built from a stack of strided convolutions."""

    def __init__(
        self,
        channels: Tuple[int, ...] = (1, 32, 64, 128),
        kernel_sizes: Tuple[int, ...] = (4, 4, 4),
        strides: Tuple[int, ...] = (2, 2, 2),
        max_pool_kernel: int = 2,
        dropout_rate: float = 0.2,
        activation: Optional[str] = "relu",
    ) -> None:
        """Initialization of the CNN backbone.

        Args:
            channels (Tuple[int, ...]): Input channels followed by the output channels of each
                convolutional layer. Defaults to (1, 32, 64, 128).
            kernel_sizes (Tuple[int, ...]): Kernel size of each convolutional layer.
                Defaults to (4, 4, 4).
            strides (Tuple[int, ...]): Stride of each convolutional layer. Defaults to (2, 2, 2).
            max_pool_kernel (int): 2D max pooling kernel. Defaults to 2.
            dropout_rate (float): The dropout rate; a falsy value disables dropout. Defaults to 0.2.
            activation (Optional[str]): The name of non-linear activation function. Defaults to relu.

        Raises:
            RuntimeError: if the number of hyperparameters does not match in length.

        """
        super().__init__()

        # One kernel size and one stride are required per convolution, and there are
        # len(channels) - 1 convolutions. Either mismatch is an error: `zip` in
        # `_build_network` would otherwise silently drop the surplus entries.
        num_conv_layers = len(channels) - 1
        if num_conv_layers != len(kernel_sizes) or len(kernel_sizes) != len(strides):
            raise RuntimeError("The number of the hyperparameters does not match.")

        self.cnn = self._build_network(
            channels, kernel_sizes, strides, max_pool_kernel, dropout_rate, activation,
        )

    def _build_network(
        self,
        channels: Tuple[int, ...],
        kernel_sizes: Tuple[int, ...],
        strides: Tuple[int, ...],
        max_pool_kernel: int,
        dropout_rate: float,
        activation: str,
    ) -> nn.Sequential:
        """Assembles the convolutional stack described by the hyperparameters."""
        # Load activation function.
        activation_fn = activation_function(activation)

        # The first entry is the number of input channels; the rest are per-layer outputs.
        channels = list(channels)
        in_channels = channels.pop(0)
        configuration = zip(channels, kernel_sizes, strides)

        modules = nn.ModuleList([])

        for i, (out_channels, kernel_size, stride) in enumerate(configuration):
            # Add max pool to reduce output size, placed in front of the middle layer.
            if i == len(channels) // 2:
                modules.append(nn.MaxPool2d(max_pool_kernel))

            if i == 0:
                # The first convolution sees the raw input: no activation or batch norm before it.
                modules.append(
                    nn.Conv2d(
                        in_channels, out_channels, kernel_size, stride=stride, padding=1
                    )
                )
            else:
                # Later layers are pre-activated: activation -> batch norm -> convolution.
                modules.append(
                    nn.Sequential(
                        activation_fn,
                        nn.BatchNorm2d(in_channels),
                        nn.Conv2d(
                            in_channels,
                            out_channels,
                            kernel_size,
                            stride=stride,
                            padding=1,
                        ),
                    )
                )

            if dropout_rate:
                modules.append(nn.Dropout2d(p=dropout_rate))

            in_channels = out_channels

        return nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """The feedforward pass."""
        # If leading dimensions (e.g. the batch dimension) are missing, prepend
        # singleton dimensions until the input is 4-D.
        if len(x.shape) < 4:
            x = x[(None,) * (4 - len(x.shape))]
        return self.cnn(x)
num_encoder_layers, - num_decoder_layers, - hidden_dim, - num_heads, - expansion_dim, - dropout_rate, - activation, - ) - - self.head = nn.Sequential( - # nn.Linear(hidden_dim, hidden_dim * 2), - # activation_function(activation), - nn.Linear(hidden_dim, vocab_size), - ) - - def _create_trg_mask(self, trg: Tensor) -> Tensor: - # Move this outside the transformer. - trg_pad_mask = (trg != self.trg_pad_index)[:, None, None] - trg_len = trg.shape[1] - trg_sub_mask = torch.tril( - torch.ones((trg_len, trg_len), device=trg.device) - ).bool() - trg_mask = trg_pad_mask & trg_sub_mask - return trg_mask - - def encoder(self, src: Tensor) -> Tensor: - """Forward pass with the encoder of the transformer.""" - return self.transformer.encoder(src) - - def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor: - """Forward pass with the decoder of the transformer + classification head.""" - return self.head( - self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask) - ) - - def extract_image_features(self, src: Tensor) -> Tensor: - """Extracts image features with a backbone neural network. - - It seem like the winning idea was to swap channels and width dimension and collapse - the height dimension. The transformer is learning like a baby with this implementation!!! :D - Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D - - Args: - src (Tensor): Input tensor. - - Returns: - Tensor: A input src to the transformer. - - """ - # If batch dimension is missing, it needs to be added. 
- if len(src.shape) < 4: - src = src[(None,) * (4 - len(src.shape))] - - src = self.backbone(src) - - if self.max_pool is not None: - src = self.max_pool(src) - - if self.adaptive_pool is not None and len(src.shape) == 4: - src = rearrange(src, "b c h w -> b w c h") - src = self.adaptive_pool(src) - src = src.squeeze(3) - elif len(src.shape) == 4: - src = rearrange(src, "b c h w -> b (h w) c") - - b, t, _ = src.shape - - src += self.src_position_embedding[:, :t] - src = self.pos_dropout(src) - - return src - - def target_embedding(self, trg: Tensor) -> Tuple[Tensor, Tensor]: - """Encodes target tensor with embedding and postion. - - Args: - trg (Tensor): Target tensor. - - Returns: - Tuple[Tensor, Tensor]: Encoded target tensor and target mask. - - """ - trg = self.character_embedding(trg.long()) - trg = self.trg_position_encoding(trg) - return trg - - def decode_image_features( - self, image_features: Tensor, trg: Optional[Tensor] = None - ) -> Tensor: - """Takes images features from the backbone and decodes them with the transformer.""" - trg_mask = self._create_trg_mask(trg) - trg = self.target_embedding(trg) - out = self.transformer(image_features, trg, trg_mask=trg_mask) - - logits = self.head(out) - return logits - - def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor: - """Forward pass with CNN transfomer.""" - image_features = self.extract_image_features(x) - logits = self.decode_image_features(image_features, trg) - return logits diff --git a/src/text_recognizer/networks/crnn.py b/src/text_recognizer/networks/crnn.py deleted file mode 100644 index 778e232..0000000 --- a/src/text_recognizer/networks/crnn.py +++ /dev/null @@ -1,110 +0,0 @@ -"""CRNN for handwritten text recognition.""" -from typing import Dict, Tuple - -from einops import rearrange, reduce -from einops.layers.torch import Rearrange -from loguru import logger -from torch import nn -from torch import Tensor - -from text_recognizer.networks.util import configure_backbone - - 
class ConvolutionalRecurrentNetwork(nn.Module):
    """Network that takes a image of a text line and predicts tokens that are in the image."""

    def __init__(
        self,
        backbone: str,
        backbone_args: Dict = None,
        input_size: int = 128,
        hidden_size: int = 128,
        bidirectional: bool = False,
        num_layers: int = 1,
        num_classes: int = 80,
        patch_size: Tuple[int, int] = (28, 28),
        stride: Tuple[int, int] = (1, 14),
        recurrent_cell: str = "lstm",
        avg_pool: bool = False,
        use_sliding_window: bool = True,
        adaptive_pool_dim: Tuple[int, int] = None,
    ) -> None:
        """Initialization of the CRNN.

        Args:
            backbone (str): Name of the backbone network, resolved by `configure_backbone`.
            backbone_args (Dict): Arguments forwarded to `configure_backbone`. Defaults to None.
            input_size (int): Feature size fed to the recurrent cell. Defaults to 128.
            hidden_size (int): Hidden size of the recurrent cell. Defaults to 128.
            bidirectional (bool): If the recurrent cell is bidirectional. Defaults to False.
            num_layers (int): Number of recurrent layers. Defaults to 1.
            num_classes (int): Number of output classes. Defaults to 80.
            patch_size (Tuple[int, int]): Height and width of the sliding window patches.
                Defaults to (28, 28).
            stride (Tuple[int, int]): Stride of the sliding window. Defaults to (1, 14).
            recurrent_cell (str): "lstm" or "gru", case-insensitive. Any other value falls back
                to LSTM with a warning. Defaults to "lstm".
            avg_pool (bool): Average the backbone feature maps over height and width in the
                sliding window path. Defaults to False.
            use_sliding_window (bool): Split the image into patches before the backbone.
                Defaults to True.
            adaptive_pool_dim (Tuple[int, int]): Output size of an optional adaptive average
                pooling applied in the non-sliding-window path. Defaults to None (no pooling).

        """
        super().__init__()
        self.backbone_args = backbone_args or {}
        self.patch_size = patch_size
        self.stride = stride
        self.sliding_window = (
            self._configure_sliding_window() if use_sliding_window else None
        )
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.backbone = configure_backbone(backbone, backbone_args)
        self.bidirectional = bidirectional
        self.avg_pool = avg_pool

        # `forward` reads this attribute in the non-sliding-window path, so it must always
        # exist; previously it was never assigned and that path raised AttributeError.
        self.adaptive_pool = (
            nn.AdaptiveAvgPool2d(adaptive_pool_dim) if adaptive_pool_dim else None
        )

        rnn_class = self._resolve_recurrent_cell(recurrent_cell)
        self.rnn = rnn_class(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            bidirectional=bidirectional,
            num_layers=num_layers,
        )

        # A bidirectional cell concatenates both directions, doubling the feature size.
        decoder_size = self.hidden_size * 2 if self.bidirectional else self.hidden_size

        self.decoder = nn.Sequential(
            nn.Linear(in_features=decoder_size, out_features=num_classes),
            nn.LogSoftmax(dim=2),
        )

    @staticmethod
    def _resolve_recurrent_cell(name: str) -> type:
        """Maps a case-insensitive cell name to its `torch.nn` class, defaulting to LSTM."""
        # The lookup has to use the upper-cased name: `torch.nn` exposes `LSTM`/`GRU`,
        # so `getattr(nn, "lstm")` (the default argument) would raise AttributeError.
        cell_name = name.upper()
        if cell_name in ["LSTM", "GRU"]:
            return getattr(nn, cell_name)
        logger.warning(f"Option {name} not valid, defaulting to LSTM cell.")
        return nn.LSTM

    def _configure_sliding_window(self) -> nn.Sequential:
        """Builds the module that cuts an image into a sequence of single-channel patches."""
        return nn.Sequential(
            nn.Unfold(kernel_size=self.patch_size, stride=self.stride),
            Rearrange(
                "b (c h w) t -> b t c h w",
                h=self.patch_size[0],
                w=self.patch_size[1],
                c=1,
            ),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Converts images to sequence of patches, feeds them to a CNN, then predictions are made with an LSTM."""
        # Prepend singleton dimensions until the input is 4-D.
        if len(x.shape) < 4:
            x = x[(None,) * (4 - len(x.shape))]

        if self.sliding_window is not None:
            # Create image patches with a sliding window kernel.
            x = self.sliding_window(x)

            # Fold the patch (time) axis into the batch axis for the feedforward network.
            b, t = x.shape[:2]
            x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t)

            x = self.backbone(x)

            # Restore the time axis, time-major as expected by the recurrent cell.
            if self.avg_pool:
                x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)
            else:
                x = rearrange(x, "(b t) h -> t b h", b=b, t=t)
        else:
            # Encode the entire image with a CNN, and use the channels as temporal dimension.
            x = self.backbone(x)
            x = rearrange(x, "b c h w -> b w c h")
            if self.adaptive_pool is not None:
                x = self.adaptive_pool(x)
            # NOTE(review): squeeze(3) only drops the last axis when it has size 1, which
            # presumably requires the backbone or adaptive pooling to collapse it - confirm.
            x = x.squeeze(3)

        # Sequence predictions.
        x, _ = self.rnn(x)

        # Sequence to classification layer.
        x = self.decoder(x)
        return x
- - """ - - if character_mapper is None: - character_mapper = EmnistMapper(pad_token="_") # noqa: S106 - - predictions = rearrange(torch.argmax(predictions, dim=2), "t b -> b t") - decoded_predictions = [] - decoded_targets = [] - for i, prediction in enumerate(predictions): - decoded_prediction = [] - decoded_target = [] - if targets is not None and target_lengths is not None: - for target_index in targets[i][: target_lengths[i]]: - if target_index == blank_label: - continue - decoded_target.append(character_mapper(int(target_index))) - decoded_targets.append(decoded_target) - for j, index in enumerate(prediction): - if index != blank_label: - if collapse_repeated and j != 0 and index == prediction[j - 1]: - continue - decoded_prediction.append(index.item()) - decoded_predictions.append( - [character_mapper(int(pred_index)) for pred_index in decoded_prediction] - ) - return decoded_predictions, decoded_targets diff --git a/src/text_recognizer/networks/densenet.py b/src/text_recognizer/networks/densenet.py deleted file mode 100644 index 7dc58d9..0000000 --- a/src/text_recognizer/networks/densenet.py +++ /dev/null @@ -1,225 +0,0 @@ -"""Defines a Densely Connected Convolutional Networks in PyTorch. 
- -Sources: -https://arxiv.org/abs/1608.06993 -https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py - -""" -from typing import List, Optional, Union - -from einops.layers.torch import Rearrange -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.util import activation_function - - -class _DenseLayer(nn.Module): - """A dense layer with pre-batch norm -> activation function -> Conv-layer x 2.""" - - def __init__( - self, - in_channels: int, - growth_rate: int, - bn_size: int, - dropout_rate: float, - activation: str = "relu", - ) -> None: - super().__init__() - activation_fn = activation_function(activation) - self.dense_layer = [ - nn.BatchNorm2d(in_channels), - activation_fn, - nn.Conv2d( - in_channels=in_channels, - out_channels=bn_size * growth_rate, - kernel_size=1, - stride=1, - bias=False, - ), - nn.BatchNorm2d(bn_size * growth_rate), - activation_fn, - nn.Conv2d( - in_channels=bn_size * growth_rate, - out_channels=growth_rate, - kernel_size=3, - stride=1, - padding=1, - bias=False, - ), - ] - if dropout_rate: - self.dense_layer.append(nn.Dropout(p=dropout_rate)) - - self.dense_layer = nn.Sequential(*self.dense_layer) - - def forward(self, x: Union[Tensor, List[Tensor]]) -> Tensor: - if isinstance(x, list): - x = torch.cat(x, 1) - return self.dense_layer(x) - - -class _DenseBlock(nn.Module): - def __init__( - self, - num_layers: int, - in_channels: int, - bn_size: int, - growth_rate: int, - dropout_rate: float, - activation: str = "relu", - ) -> None: - super().__init__() - self.dense_block = self._build_dense_blocks( - num_layers, in_channels, bn_size, growth_rate, dropout_rate, activation, - ) - - def _build_dense_blocks( - self, - num_layers: int, - in_channels: int, - bn_size: int, - growth_rate: int, - dropout_rate: float, - activation: str = "relu", - ) -> nn.ModuleList: - dense_block = [] - for i in range(num_layers): - dense_block.append( - _DenseLayer( - in_channels=in_channels + i * 
growth_rate, - growth_rate=growth_rate, - bn_size=bn_size, - dropout_rate=dropout_rate, - activation=activation, - ) - ) - return nn.ModuleList(dense_block) - - def forward(self, x: Tensor) -> Tensor: - feature_maps = [x] - for layer in self.dense_block: - x = layer(feature_maps) - feature_maps.append(x) - return torch.cat(feature_maps, 1) - - -class _Transition(nn.Module): - def __init__( - self, in_channels: int, out_channels: int, activation: str = "relu", - ) -> None: - super().__init__() - activation_fn = activation_function(activation) - self.transition = nn.Sequential( - nn.BatchNorm2d(in_channels), - activation_fn, - nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - stride=1, - bias=False, - ), - nn.AvgPool2d(kernel_size=2, stride=2), - ) - - def forward(self, x: Tensor) -> Tensor: - return self.transition(x) - - -class DenseNet(nn.Module): - """Implementation of Densenet, a network archtecture that concats previous layers for maximum infomation flow.""" - - def __init__( - self, - growth_rate: int = 32, - block_config: List[int] = (6, 12, 24, 16), - in_channels: int = 1, - base_channels: int = 64, - num_classes: int = 80, - bn_size: int = 4, - dropout_rate: float = 0, - classifier: bool = True, - activation: str = "relu", - ) -> None: - super().__init__() - self.densenet = self._configure_densenet( - in_channels, - base_channels, - num_classes, - growth_rate, - block_config, - bn_size, - dropout_rate, - classifier, - activation, - ) - - def _configure_densenet( - self, - in_channels: int, - base_channels: int, - num_classes: int, - growth_rate: int, - block_config: List[int], - bn_size: int, - dropout_rate: float, - classifier: bool, - activation: str, - ) -> nn.Sequential: - activation_fn = activation_function(activation) - densenet = [ - nn.Conv2d( - in_channels=in_channels, - out_channels=base_channels, - kernel_size=3, - stride=1, - padding=1, - bias=False, - ), - nn.BatchNorm2d(base_channels), - activation_fn, - ] - 
- num_features = base_channels - - for i, num_layers in enumerate(block_config): - densenet.append( - _DenseBlock( - num_layers=num_layers, - in_channels=num_features, - bn_size=bn_size, - growth_rate=growth_rate, - dropout_rate=dropout_rate, - activation=activation, - ) - ) - num_features = num_features + num_layers * growth_rate - if i != len(block_config) - 1: - densenet.append( - _Transition( - in_channels=num_features, - out_channels=num_features // 2, - activation=activation, - ) - ) - num_features = num_features // 2 - - densenet.append(activation_fn) - - if classifier: - densenet.append(nn.AdaptiveAvgPool2d((1, 1))) - densenet.append(Rearrange("b c h w -> b (c h w)")) - densenet.append( - nn.Linear(in_features=num_features, out_features=num_classes) - ) - - return nn.Sequential(*densenet) - - def forward(self, x: Tensor) -> Tensor: - """Forward pass of Densenet.""" - # If batch dimenstion is missing, it will be added. - if len(x.shape) < 4: - x = x[(None,) * (4 - len(x.shape))] - return self.densenet(x) diff --git a/src/text_recognizer/networks/lenet.py b/src/text_recognizer/networks/lenet.py deleted file mode 100644 index 527e1a0..0000000 --- a/src/text_recognizer/networks/lenet.py +++ /dev/null @@ -1,68 +0,0 @@ -"""Implementation of the LeNet network.""" -from typing import Callable, Dict, Optional, Tuple - -from einops.layers.torch import Rearrange -import torch -from torch import nn - -from text_recognizer.networks.util import activation_function - - -class LeNet(nn.Module): - """LeNet network for character prediction.""" - - def __init__( - self, - channels: Tuple[int, ...] = (1, 32, 64), - kernel_sizes: Tuple[int, ...] = (3, 3, 2), - hidden_size: Tuple[int, ...] = (9216, 128), - dropout_rate: float = 0.2, - num_classes: int = 10, - activation_fn: Optional[str] = "relu", - ) -> None: - """Initialization of the LeNet network. - - Args: - channels (Tuple[int, ...]): Channels in the convolutional layers. Defaults to (1, 32, 64). 
- kernel_sizes (Tuple[int, ...]): Kernel sizes in the convolutional layers. Defaults to (3, 3, 2). - hidden_size (Tuple[int, ...]): Size of the flattend output form the convolutional layers. - Defaults to (9216, 128). - dropout_rate (float): The dropout rate. Defaults to 0.2. - num_classes (int): Number of classes. Defaults to 10. - activation_fn (Optional[str]): The name of non-linear activation function. Defaults to relu. - - """ - super().__init__() - - activation_fn = activation_function(activation_fn) - - self.layers = [ - nn.Conv2d( - in_channels=channels[0], - out_channels=channels[1], - kernel_size=kernel_sizes[0], - ), - activation_fn, - nn.Conv2d( - in_channels=channels[1], - out_channels=channels[2], - kernel_size=kernel_sizes[1], - ), - activation_fn, - nn.MaxPool2d(kernel_sizes[2]), - nn.Dropout(p=dropout_rate), - Rearrange("b c h w -> b (c h w)"), - nn.Linear(in_features=hidden_size[0], out_features=hidden_size[1]), - activation_fn, - nn.Dropout(p=dropout_rate), - nn.Linear(in_features=hidden_size[1], out_features=num_classes), - ] - - self.layers = nn.Sequential(*self.layers) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """The feedforward pass.""" - # If batch dimenstion is missing, it needs to be added. 
- if len(x.shape) < 4: - x = x[(None,) * (4 - len(x.shape))] - return self.layers(x) diff --git a/src/text_recognizer/networks/loss/__init__.py b/src/text_recognizer/networks/loss/__init__.py deleted file mode 100644 index b489264..0000000 --- a/src/text_recognizer/networks/loss/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""Loss module.""" -from .loss import EmbeddingLoss, LabelSmoothingCrossEntropy diff --git a/src/text_recognizer/networks/loss/loss.py b/src/text_recognizer/networks/loss/loss.py deleted file mode 100644 index cf9fa0d..0000000 --- a/src/text_recognizer/networks/loss/loss.py +++ /dev/null @@ -1,69 +0,0 @@ -"""Implementations of custom loss functions.""" -from pytorch_metric_learning import distances, losses, miners, reducers -import torch -from torch import nn -from torch import Tensor -from torch.autograd import Variable -import torch.nn.functional as F - -__all__ = ["EmbeddingLoss", "LabelSmoothingCrossEntropy"] - - -class EmbeddingLoss: - """Metric loss for training encoders to produce information-rich latent embeddings.""" - - def __init__(self, margin: float = 0.2, type_of_triplets: str = "semihard") -> None: - self.distance = distances.CosineSimilarity() - self.reducer = reducers.ThresholdReducer(low=0) - self.loss_fn = losses.TripletMarginLoss( - margin=margin, distance=self.distance, reducer=self.reducer - ) - self.miner = miners.MultiSimilarityMiner(epsilon=margin, distance=self.distance) - - def __call__(self, embeddings: Tensor, labels: Tensor) -> Tensor: - """Computes the metric loss for the embeddings based on their labels. - - Args: - embeddings (Tensor): The laten vectors encoded by the network. - labels (Tensor): Labels of the embeddings. - - Returns: - Tensor: The metric loss for the embeddings. 
def accuracy(outputs: Tensor, labels: Tensor, pad_index: int = 53) -> float:
    """Computes the accuracy over the non-padded labels.

    Args:
        outputs (Tensor): The output from the network, with the class scores in the last dimension.
        labels (Tensor): Ground truth labels. The tensor is not modified.
        pad_index (int): Padding index; positions holding it are excluded from the metric.

    Returns:
        float: Fraction of non-padded labels predicted correctly, or 0.0 if every label is padding.

    """
    predicted = outputs.argmax(dim=-1)

    # Mask out the pad tokens. The mask is combined out-of-place so that the caller's
    # `labels` tensor (typically reused for the loss) is left untouched.
    mask = labels != pad_index
    num_tokens = int(mask.sum())
    if num_tokens == 0:
        return 0.0

    correct = ((predicted == labels) & mask).sum()
    return (correct.float() / num_tokens).item()


def _decode_batch(
    outputs: Tensor, targets: Tensor, batch_size: Optional[int], blank_label: Optional[int],
) -> tuple:
    """Greedy-decodes outputs and targets, unflattening them first if a batch size is given."""
    if len(outputs.shape) == 2 and len(targets.shape) == 1 and batch_size is not None:
        targets = rearrange(targets, "(b t) -> b t", b=batch_size)
        outputs = rearrange(outputs, "(b t) v -> t b v", b=batch_size)

    # Every target is treated as spanning the full (padded) target width.
    target_lengths = torch.full(
        size=(outputs.shape[1],), fill_value=targets.shape[1], dtype=torch.long,
    )
    return greedy_decoder(outputs, targets, target_lengths, blank_label=blank_label)


def cer(
    outputs: Tensor,
    targets: Tensor,
    batch_size: Optional[int] = None,
    blank_label: Optional[int] = 79,
) -> float:
    """Computes the character error rate.

    Args:
        outputs (Tensor): The output from the network.
        targets (Tensor): Ground truth labels.
        batch_size (Optional[int]): Batch size if target and output has been flattend.
        blank_label (Optional[int]): The blank character to be ignored. Defaults to 79.

    Returns:
        float: The mean character-level Levenshtein distance per sample, computed with
            spaces removed. It is not normalized by the target length.

    """
    decoded_predictions, decoded_targets = _decode_batch(
        outputs, targets, batch_size, blank_label
    )

    lev_dist = 0

    for prediction, target in zip(decoded_predictions, decoded_targets):
        prediction = "".join(prediction).replace(" ", "")
        target = "".join(target).replace(" ", "")
        lev_dist += Lev.distance(prediction, target)
    return lev_dist / len(decoded_predictions)


def wer(
    outputs: Tensor,
    targets: Tensor,
    batch_size: Optional[int] = None,
    blank_label: Optional[int] = 79,
) -> float:
    """Computes the Word error rate.

    Args:
        outputs (Tensor): The output from the network.
        targets (Tensor): Ground truth labels.
        batch_size (optional[int]): Batch size if target and output has been flattend.
        blank_label (Optional[int]): The blank character to be ignored. Defaults to 79.

    Returns:
        float: The mean word-level Levenshtein distance per sample. It is not normalized by
            the number of words in the target.

    """
    decoded_predictions, decoded_targets = _decode_batch(
        outputs, targets, batch_size, blank_label
    )

    lev_dist = 0

    for prediction, target in zip(decoded_predictions, decoded_targets):
        prediction_words = "".join(prediction).split()
        target_words = "".join(target).split()

        # Map every distinct word to a single character so that the character-level
        # Levenshtein distance counts word edits.
        vocabulary = set(prediction_words + target_words)
        word2char = dict(zip(vocabulary, range(len(vocabulary))))

        w1 = [chr(word2char[w]) for w in prediction_words]
        w2 = [chr(word2char[w]) for w in target_words]

        lev_dist += Lev.distance("".join(w1), "".join(w2))

    return lev_dist / len(decoded_predictions)
- dropout_rate (float): The dropout rate at each layer. Defaults to 0.2. - activation_fn (str): Name of the activation function in the hidden layers. Defaults to - relu. - - """ - super().__init__() - - activation_fn = activation_function(activation_fn) - - if isinstance(hidden_size, int): - hidden_size = [hidden_size] * num_layers - - self.layers = [ - Rearrange("b c h w -> b (c h w)"), - nn.Linear(in_features=input_size, out_features=hidden_size[0]), - activation_fn, - ] - - for i in range(num_layers - 1): - self.layers += [ - nn.Linear(in_features=hidden_size[i], out_features=hidden_size[i + 1]), - activation_fn, - ] - - if dropout_rate: - self.layers.append(nn.Dropout(p=dropout_rate)) - - self.layers.append( - nn.Linear(in_features=hidden_size[-1], out_features=num_classes) - ) - - self.layers = nn.Sequential(*self.layers) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """The feedforward pass.""" - # If batch dimenstion is missing, it needs to be added. - if len(x.shape) < 4: - x = x[(None,) * (4 - len(x.shape))] - return self.layers(x) - - @property - def __name__(self) -> str: - """Returns the name of the network.""" - return "mlp" diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py deleted file mode 100644 index c33f419..0000000 --- a/src/text_recognizer/networks/residual_network.py +++ /dev/null @@ -1,310 +0,0 @@ -"""Residual CNN.""" -from functools import partial -from typing import Callable, Dict, List, Optional, Type, Union - -from einops.layers.torch import Rearrange, Reduce -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.util import activation_function - - -class Conv2dAuto(nn.Conv2d): - """Convolution with auto padding based on kernel size.""" - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2) - - -def conv_bn(in_channels: int, 
out_channels: int, *args, **kwargs) -> nn.Sequential: - """3x3 convolution with batch norm.""" - conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False,) - return nn.Sequential( - conv3x3(in_channels, out_channels, *args, **kwargs), - nn.BatchNorm2d(out_channels), - ) - - -class IdentityBlock(nn.Module): - """Residual with identity block.""" - - def __init__( - self, in_channels: int, out_channels: int, activation: str = "relu" - ) -> None: - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.blocks = nn.Identity() - self.activation_fn = activation_function(activation) - self.shortcut = nn.Identity() - - def forward(self, x: Tensor) -> Tensor: - """Forward pass.""" - residual = x - if self.apply_shortcut: - residual = self.shortcut(x) - x = self.blocks(x) - x += residual - x = self.activation_fn(x) - return x - - @property - def apply_shortcut(self) -> bool: - """Check if shortcut should be applied.""" - return self.in_channels != self.out_channels - - -class ResidualBlock(IdentityBlock): - """Residual with nonlinear shortcut.""" - - def __init__( - self, - in_channels: int, - out_channels: int, - expansion: int = 1, - downsampling: int = 1, - *args, - **kwargs - ) -> None: - """Short summary. - - Args: - in_channels (int): Number of in channels. - out_channels (int): umber of out channels. - expansion (int): Expansion factor of the out channels. Defaults to 1. - downsampling (int): Downsampling factor used in stride. Defaults to 1. - *args (type): Extra arguments. - **kwargs (type): Extra key value arguments. 
- - """ - super().__init__(in_channels, out_channels, *args, **kwargs) - self.expansion = expansion - self.downsampling = downsampling - - self.shortcut = ( - nn.Sequential( - nn.Conv2d( - in_channels=self.in_channels, - out_channels=self.expanded_channels, - kernel_size=1, - stride=self.downsampling, - bias=False, - ), - nn.BatchNorm2d(self.expanded_channels), - ) - if self.apply_shortcut - else None - ) - - @property - def expanded_channels(self) -> int: - """Computes the expanded output channels.""" - return self.out_channels * self.expansion - - @property - def apply_shortcut(self) -> bool: - """Check if shortcut should be applied.""" - return self.in_channels != self.expanded_channels - - -class BasicBlock(ResidualBlock): - """Basic ResNet block.""" - - expansion = 1 - - def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None: - super().__init__(in_channels, out_channels, *args, **kwargs) - self.blocks = nn.Sequential( - conv_bn( - in_channels=self.in_channels, - out_channels=self.out_channels, - bias=False, - stride=self.downsampling, - ), - self.activation_fn, - conv_bn( - in_channels=self.out_channels, - out_channels=self.expanded_channels, - bias=False, - ), - ) - - -class BottleNeckBlock(ResidualBlock): - """Bottleneck block to increase depth while minimizing parameter size.""" - - expansion = 4 - - def __init__(self, in_channels: int, out_channels: int, *args, **kwargs) -> None: - super().__init__(in_channels, out_channels, *args, **kwargs) - self.blocks = nn.Sequential( - conv_bn( - in_channels=self.in_channels, - out_channels=self.out_channels, - kernel_size=1, - ), - self.activation_fn, - conv_bn( - in_channels=self.out_channels, - out_channels=self.out_channels, - kernel_size=3, - stride=self.downsampling, - ), - self.activation_fn, - conv_bn( - in_channels=self.out_channels, - out_channels=self.expanded_channels, - kernel_size=1, - ), - ) - - -class ResidualLayer(nn.Module): - """ResNet layer.""" - - def __init__( - self, - 
in_channels: int, - out_channels: int, - block: BasicBlock = BasicBlock, - num_blocks: int = 1, - *args, - **kwargs - ) -> None: - super().__init__() - downsampling = 2 if in_channels != out_channels else 1 - self.blocks = nn.Sequential( - block( - in_channels, out_channels, *args, **kwargs, downsampling=downsampling - ), - *[ - block( - out_channels * block.expansion, - out_channels, - downsampling=1, - *args, - **kwargs - ) - for _ in range(num_blocks - 1) - ] - ) - - def forward(self, x: Tensor) -> Tensor: - """Forward pass.""" - x = self.blocks(x) - return x - - -class ResidualNetworkEncoder(nn.Module): - """Encoder network.""" - - def __init__( - self, - in_channels: int = 1, - block_sizes: Union[int, List[int]] = (32, 64), - depths: Union[int, List[int]] = (2, 2), - activation: str = "relu", - block: Type[nn.Module] = BasicBlock, - levels: int = 1, - *args, - **kwargs - ) -> None: - super().__init__() - self.block_sizes = ( - block_sizes if isinstance(block_sizes, list) else [block_sizes] * levels - ) - self.depths = depths if isinstance(depths, list) else [depths] * levels - self.activation = activation - self.gate = nn.Sequential( - nn.Conv2d( - in_channels=in_channels, - out_channels=self.block_sizes[0], - kernel_size=7, - stride=2, - padding=1, - bias=False, - ), - nn.BatchNorm2d(self.block_sizes[0]), - activation_function(self.activation), - # nn.MaxPool2d(kernel_size=2, stride=2, padding=1), - ) - - self.blocks = self._configure_blocks(block) - - def _configure_blocks( - self, block: Type[nn.Module], *args, **kwargs - ) -> nn.Sequential: - channels = [self.block_sizes[0]] + list( - zip(self.block_sizes, self.block_sizes[1:]) - ) - blocks = [ - ResidualLayer( - in_channels=channels[0], - out_channels=channels[0], - num_blocks=self.depths[0], - block=block, - activation=self.activation, - *args, - **kwargs - ) - ] - blocks += [ - ResidualLayer( - in_channels=in_channels * block.expansion, - out_channels=out_channels, - num_blocks=num_blocks, - 
block=block, - activation=self.activation, - *args, - **kwargs - ) - for (in_channels, out_channels), num_blocks in zip( - channels[1:], self.depths[1:] - ) - ] - - return nn.Sequential(*blocks) - - def forward(self, x: Tensor) -> Tensor: - """Forward pass.""" - # If batch dimenstion is missing, it needs to be added. - if len(x.shape) == 3: - x = x.unsqueeze(0) - x = self.gate(x) - x = self.blocks(x) - return x - - -class ResidualNetworkDecoder(nn.Module): - """Classification head.""" - - def __init__(self, in_features: int, num_classes: int = 80) -> None: - super().__init__() - self.decoder = nn.Sequential( - Reduce("b c h w -> b c", "mean"), - nn.Linear(in_features=in_features, out_features=num_classes), - ) - - def forward(self, x: Tensor) -> Tensor: - """Forward pass.""" - return self.decoder(x) - - -class ResidualNetwork(nn.Module): - """Full residual network.""" - - def __init__(self, in_channels: int, num_classes: int, *args, **kwargs) -> None: - super().__init__() - self.encoder = ResidualNetworkEncoder(in_channels, *args, **kwargs) - self.decoder = ResidualNetworkDecoder( - in_features=self.encoder.blocks[-1].blocks[-1].expanded_channels, - num_classes=num_classes, - ) - - def forward(self, x: Tensor) -> Tensor: - """Forward pass.""" - x = self.encoder(x) - x = self.decoder(x) - return x diff --git a/src/text_recognizer/networks/stn.py b/src/text_recognizer/networks/stn.py deleted file mode 100644 index e9d216f..0000000 --- a/src/text_recognizer/networks/stn.py +++ /dev/null @@ -1,44 +0,0 @@ -"""Spatial Transformer Network.""" - -from einops.layers.torch import Rearrange -import torch -from torch import nn -from torch import Tensor -import torch.nn.functional as F - - -class SpatialTransformerNetwork(nn.Module): - """A network with differentiable attention. - - Network that learns how to perform spatial transformations on the input image in order to enhance the - geometric invariance of the model. - - # TODO: add arguments to make it more general. 
- - """ - - def __init__(self) -> None: - super().__init__() - # Initialize the identity transformation and its weights and biases. - linear = nn.Linear(32, 3 * 2) - linear.weight.data.zero_() - linear.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)) - - self.theta = nn.Sequential( - nn.Conv2d(in_channels=1, out_channels=8, kernel_size=7), - nn.MaxPool2d(kernel_size=2, stride=2), - nn.ReLU(inplace=True), - nn.Conv2d(in_channels=8, out_channels=10, kernel_size=5), - nn.MaxPool2d(kernel_size=2, stride=2), - nn.ReLU(inplace=True), - Rearrange("b c h w -> b (c h w)", h=3, w=3), - nn.Linear(in_features=10 * 3 * 3, out_features=32), - nn.ReLU(inplace=True), - linear, - Rearrange("b (row col) -> b row col", row=2, col=3), - ) - - def forward(self, x: Tensor) -> Tensor: - """The spatial transformation.""" - grid = F.affine_grid(self.theta(x), x.shape) - return F.grid_sample(x, grid, align_corners=False) diff --git a/src/text_recognizer/networks/transducer/__init__.py b/src/text_recognizer/networks/transducer/__init__.py deleted file mode 100644 index 8c19a01..0000000 --- a/src/text_recognizer/networks/transducer/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Transducer modules.""" -from .tds_conv import TDS2d -from .transducer import load_transducer_loss, Transducer diff --git a/src/text_recognizer/networks/transducer/tds_conv.py b/src/text_recognizer/networks/transducer/tds_conv.py deleted file mode 100644 index 5fb8ba9..0000000 --- a/src/text_recognizer/networks/transducer/tds_conv.py +++ /dev/null @@ -1,208 +0,0 @@ -"""Time-Depth Separable Convolutions. 
- -References: - https://arxiv.org/abs/1904.02619 - https://arxiv.org/pdf/2010.01003.pdf - -Code stolen from: - https://github.com/facebookresearch/gtn_applications - - -""" -from typing import List, Tuple - -from einops import rearrange -import gtn -import numpy as np -import torch -from torch import nn -from torch import Tensor - - -class TDSBlock2d(nn.Module): - """Internal block of a 2D TDSC network.""" - - def __init__( - self, - in_channels: int, - img_depth: int, - kernel_size: Tuple[int], - dropout_rate: float, - ) -> None: - super().__init__() - - self.in_channels = in_channels - self.img_depth = img_depth - self.kernel_size = kernel_size - self.dropout_rate = dropout_rate - self.fc_dim = in_channels * img_depth - - # Network placeholders. - self.conv = None - self.mlp = None - self.instance_norm = None - - self._build_block() - - def _build_block(self) -> None: - # Convolutional block. - self.conv = nn.Sequential( - nn.Conv3d( - in_channels=self.in_channels, - out_channels=self.in_channels, - kernel_size=(1, self.kernel_size[0], self.kernel_size[1]), - padding=(0, self.kernel_size[0] // 2, self.kernel_size[1] // 2), - ), - nn.ReLU(inplace=True), - nn.Dropout(self.dropout_rate), - ) - - # MLP block. - self.mlp = nn.Sequential( - nn.Linear(self.fc_dim, self.fc_dim), - nn.ReLU(inplace=True), - nn.Dropout(self.dropout_rate), - nn.Linear(self.fc_dim, self.fc_dim), - nn.Dropout(self.dropout_rate), - ) - - # Instance norm. - self.instance_norm = nn.ModuleList( - [ - nn.InstanceNorm2d(self.fc_dim, affine=True), - nn.InstanceNorm2d(self.fc_dim, affine=True), - ] - ) - - def forward(self, x: Tensor) -> Tensor: - """Forward pass. - - Args: - x (Tensor): Input tensor. - - Shape: - - x: :math: `(B, CD, H, W)` - - Returns: - Tensor: Output tensor. 
- - """ - B, CD, H, W = x.shape - C, D = self.in_channels, self.img_depth - residual = x - x = rearrange(x, "b (c d) h w -> b c d h w", c=C, d=D) - x = self.conv(x) - x = rearrange(x, "b c d h w -> b (c d) h w") - x += residual - - x = self.instance_norm[0](x) - - x = self.mlp(x.transpose(1, 3)).transpose(1, 3) + x - x + self.instance_norm[1](x) - - # Output shape: [B, CD, H, W] - return x - - -class TDS2d(nn.Module): - """TDS Netowrk. - - Structure is the following: - Downsample layer -> TDS2d group -> ... -> Linear output layer - - - """ - - def __init__( - self, - input_dim: int, - output_dim: int, - depth: int, - tds_groups: Tuple[int], - kernel_size: Tuple[int], - dropout_rate: float, - in_channels: int = 1, - ) -> None: - super().__init__() - - self.in_channels = in_channels - self.input_dim = input_dim - self.output_dim = output_dim - self.depth = depth - self.tds_groups = tds_groups - self.kernel_size = kernel_size - self.dropout_rate = dropout_rate - - self.tds = None - self.fc = None - - self._build_network() - - def _build_network(self) -> None: - in_channels = self.in_channels - modules = [] - stride_h = np.prod([grp["stride"][0] for grp in self.tds_groups]) - if self.input_dim % stride_h: - raise RuntimeError( - f"Image height not divisible by total stride {stride_h}." - ) - - for tds_group in self.tds_groups: - # Add downsample layer. 
- out_channels = self.depth * tds_group["channels"] - modules.extend( - [ - nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=self.kernel_size, - padding=(self.kernel_size[0] // 2, self.kernel_size[1] // 2), - stride=tds_group["stride"], - ), - nn.ReLU(inplace=True), - nn.Dropout(self.dropout_rate), - nn.InstanceNorm2d(out_channels, affine=True), - ] - ) - - for _ in range(tds_group["num_blocks"]): - modules.append( - TDSBlock2d( - tds_group["channels"], - self.depth, - self.kernel_size, - self.dropout_rate, - ) - ) - - in_channels = out_channels - - self.tds = nn.Sequential(*modules) - self.fc = nn.Linear(in_channels * self.input_dim // stride_h, self.output_dim) - - def forward(self, x: Tensor) -> Tensor: - """Forward pass. - - Args: - x (Tensor): Input tensor. - - Shape: - - x: :math: `(B, H, W)` - - Returns: - Tensor: Output tensor. - - """ - if len(x.shape) == 4: - x = x.squeeze(1) # Squeeze the channel dim away. - - B, H, W = x.shape - x = rearrange( - x, "b (h1 h2) w -> b h1 h2 w", h1=self.in_channels, h2=H // self.in_channels - ) - x = self.tds(x) - - # x shape: [B, C, H, W] - x = rearrange(x, "b c h w -> b w (c h)") - - return self.fc(x) diff --git a/src/text_recognizer/networks/transducer/test.py b/src/text_recognizer/networks/transducer/test.py deleted file mode 100644 index cadcecc..0000000 --- a/src/text_recognizer/networks/transducer/test.py +++ /dev/null @@ -1,60 +0,0 @@ -import torch -from torch import nn - -from text_recognizer.networks.transducer import load_transducer_loss, Transducer -import unittest - - -class TestTransducer(unittest.TestCase): - def test_viterbi(self): - T = 5 - N = 4 - B = 2 - - # fmt: off - emissions1 = torch.tensor(( - 0, 4, 0, 1, - 0, 2, 1, 1, - 0, 0, 0, 2, - 0, 0, 0, 2, - 8, 0, 0, 2, - ), - dtype=torch.float, - ).view(T, N) - emissions2 = torch.tensor(( - 0, 2, 1, 7, - 0, 2, 9, 1, - 0, 0, 0, 2, - 0, 0, 5, 2, - 1, 0, 0, 2, - ), - dtype=torch.float, - ).view(T, N) - # fmt: on - - # Test 
without blank: - labels = [[1, 3, 0], [3, 2, 3, 2, 3]] - transducer = Transducer( - tokens=["a", "b", "c", "d"], - graphemes_to_idx={"a": 0, "b": 1, "c": 2, "d": 3}, - blank="none", - ) - emissions = torch.stack([emissions1, emissions2], dim=0) - predictions = transducer.viterbi(emissions) - self.assertEqual([p.tolist() for p in predictions], labels) - - # Test with blank without repeats: - labels = [[1, 0], [2, 2]] - transducer = Transducer( - tokens=["a", "b", "c"], - graphemes_to_idx={"a": 0, "b": 1, "c": 2}, - blank="optional", - allow_repeats=False, - ) - emissions = torch.stack([emissions1, emissions2], dim=0) - predictions = transducer.viterbi(emissions) - self.assertEqual([p.tolist() for p in predictions], labels) - - -if __name__ == "__main__": - unittest.main() diff --git a/src/text_recognizer/networks/transducer/transducer.py b/src/text_recognizer/networks/transducer/transducer.py deleted file mode 100644 index d7e3d08..0000000 --- a/src/text_recognizer/networks/transducer/transducer.py +++ /dev/null @@ -1,410 +0,0 @@ -"""Transducer and the transducer loss function.py - -Stolen from: - https://github.com/facebookresearch/gtn_applications/blob/master/transducer.py - -""" -from pathlib import Path -import itertools -from typing import Dict, List, Optional, Union, Tuple - -from loguru import logger -import gtn -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.datasets.iam_preprocessor import Preprocessor - - -def make_scalar_graph(weight) -> gtn.Graph: - scalar = gtn.Graph() - scalar.add_node(True) - scalar.add_node(False, True) - scalar.add_arc(0, 1, 0, 0, weight) - return scalar - - -def make_chain_graph(sequence) -> gtn.Graph: - graph = gtn.Graph(False) - graph.add_node(True) - for i, s in enumerate(sequence): - graph.add_node(False, i == (len(sequence) - 1)) - graph.add_arc(i, i + 1, s) - return graph - - -def make_transitions_graph( - ngram: int, num_tokens: int, calc_grad: bool = False -) -> gtn.Graph: - transitions 
= gtn.Graph(calc_grad) - transitions.add_node(True, ngram == 1) - - state_map = {(): 0} - - # First build transitions which include <s>: - for n in range(1, ngram): - for state in itertools.product(range(num_tokens), repeat=n): - in_idx = state_map[state[:-1]] - out_idx = transitions.add_node(False, ngram == 1) - state_map[state] = out_idx - transitions.add_arc(in_idx, out_idx, state[-1]) - - for state in itertools.product(range(num_tokens), repeat=ngram): - state_idx = state_map[state[:-1]] - new_state_idx = state_map[state[1:]] - # p(state[-1] | state[:-1]) - transitions.add_arc(state_idx, new_state_idx, state[-1]) - - if ngram > 1: - # Build transitions which include </s>: - end_idx = transitions.add_node(False, True) - for in_idx in range(end_idx): - transitions.add_arc(in_idx, end_idx, gtn.epsilon) - - return transitions - - -def make_lexicon_graph(word_pieces: List, graphemes_to_idx: Dict) -> gtn.Graph: - """Constructs a graph which transduces letters to word pieces.""" - graph = gtn.Graph(False) - graph.add_node(True, True) - for i, wp in enumerate(word_pieces): - prev = 0 - for l in wp[:-1]: - n = graph.add_node() - graph.add_arc(prev, n, graphemes_to_idx[l], gtn.epsilon) - prev = n - graph.add_arc(prev, 0, graphemes_to_idx[wp[-1]], i) - graph.arc_sort() - return graph - - -def make_token_graph( - token_list: List, blank: str = "none", allow_repeats: bool = True -) -> gtn.Graph: - """Constructs a graph with all the individual token transition models.""" - if not allow_repeats and blank != "optional": - raise ValueError("Must use blank='optional' if disallowing repeats.") - - ntoks = len(token_list) - graph = gtn.Graph(False) - - # Creating nodes - graph.add_node(True, True) - for i in range(ntoks): - # We can consume one or more consecutive word - # pieces for each emission: - # E.g. 
[ab, ab, ab] transduces to [ab] - graph.add_node(False, blank != "forced") - - if blank != "none": - graph.add_node() - - # Creating arcs - if blank != "none": - # Blank index is assumed to be last (ntoks) - graph.add_arc(0, ntoks + 1, ntoks, gtn.epsilon) - graph.add_arc(ntoks + 1, 0, gtn.epsilon) - - for i in range(ntoks): - graph.add_arc((ntoks + 1) if blank == "forced" else 0, i + 1, i) - graph.add_arc(i + 1, i + 1, i, gtn.epsilon) - - if allow_repeats: - if blank == "forced": - # Allow transitions from token to blank only - graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon) - else: - # Allow transition from token to blank and all other tokens - graph.add_arc(i + 1, 0, gtn.epsilon) - - else: - # allow transitions to blank and all other tokens except the same token - graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon) - for j in range(ntoks): - if i != j: - graph.add_arc(i + 1, j + 1, j, j) - - return graph - - -class TransducerLossFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - inputs, - targets, - tokens, - lexicon, - transition_params=None, - transitions=None, - reduction="none", - ) -> Tensor: - B, T, C = inputs.shape - - losses = [None] * B - emissions_graphs = [None] * B - - if transitions is not None: - if transition_params is None: - raise ValueError("Specified transitions, but not transition params.") - - cpu_data = transition_params.cpu().contiguous() - transitions.set_weights(cpu_data.data_ptr()) - transitions.calc_grad = transition_params.requires_grad - transitions.zero_grad() - - def process(b: int) -> None: - # Create emission graph: - emissions = gtn.linear_graph(T, C, inputs.requires_grad) - cpu_data = inputs[b].cpu().contiguous() - emissions.set_weights(cpu_data.data_ptr()) - target = make_chain_graph(targets[b]) - target.arc_sort(True) - - # Create token tot grapheme decomposition graph - tokens_target = gtn.remove(gtn.project_output(gtn.compose(target, lexicon))) - tokens_target.arc_sort() - - # Create alignment 
graph: - aligments = gtn.project_input( - gtn.remove(gtn.compose(tokens, tokens_target)) - ) - aligments.arc_sort() - - # Add transitions scores: - if transitions is not None: - aligments = gtn.intersect(transitions, aligments) - aligments.arc_sort() - - loss = gtn.forward_score(gtn.intersect(emissions, aligments)) - - # Normalize if needed: - if transitions is not None: - norm = gtn.forward_score(gtn.intersect(emissions, transitions)) - loss = gtn.subtract(loss, norm) - - losses[b] = gtn.negate(loss) - - # Save for backward: - if emissions.calc_grad: - emissions_graphs[b] = emissions - - gtn.parallel_for(process, range(B)) - - ctx.graphs = (losses, emissions_graphs, transitions) - ctx.input_shape = inputs.shape - - # Optionally reduce by target length - if reduction == "mean": - scales = [(1 / len(t) if len(t) > 0 else 1.0) for t in targets] - else: - scales = [1.0] * B - - ctx.scales = scales - - loss = torch.tensor([l.item() * s for l, s in zip(losses, scales)]) - return torch.mean(loss.to(inputs.device)) - - @staticmethod - def backward(ctx, grad_output) -> Tuple: - losses, emissions_graphs, transitions = ctx.graphs - scales = ctx.scales - - B, T, C = ctx.input_shape - calc_emissions = ctx.needs_input_grad[0] - input_grad = torch.empty((B, T, C)) if calc_emissions else None - - def process(b: int) -> None: - scale = make_scalar_graph(scales[b]) - gtn.backward(losses[b], scale) - emissions = emissions_graphs[b] - if calc_emissions: - grad = emissions.grad().weights_to_numpy() - input_grad[b] = torch.tensor(grad).view(1, T, C) - - gtn.parallel_for(process, range(B)) - - if calc_emissions: - input_grad = input_grad.to(grad_output.device) - input_grad *= grad_output / B - - if ctx.needs_input_grad[4]: - grad = transitions.grad().weights_to_numpy() - transition_grad = torch.tensor(grad).to(grad_output.device) - transition_grad *= grad_output / B - else: - transition_grad = None - - return ( - input_grad, - None, # target - None, # tokens - None, # lexicon - 
transition_grad, # transition params - None, # transitions graph - None, - ) - - -TransducerLoss = TransducerLossFunction.apply - - -class Transducer(nn.Module): - def __init__( - self, - tokens: List, - graphemes_to_idx: Dict, - ngram: int = 0, - transitions: str = None, - blank: str = "none", - allow_repeats: bool = True, - reduction: str = "none", - ) -> None: - """A generic transducer loss function. - - Args: - tokens (List) : A list of iterable objects (e.g. strings, tuples, etc) - representing the output tokens of the model (e.g. letters, - word-pieces, words). For example ["a", "b", "ab", "ba", "aba"] - could be a list of sub-word tokens. - graphemes_to_idx (dict) : A dictionary mapping grapheme units (e.g. - "a", "b", ..) to their corresponding integer index. - ngram (int) : Order of the token-level transition model. If `ngram=0` - then no transition model is used. - blank (string) : Specifies the usage of blank token - 'none' - do not use blank token - 'optional' - allow an optional blank inbetween tokens - 'forced' - force a blank inbetween tokens (also referred to as garbage token) - allow_repeats (boolean) : If false, then we don't allow paths with - consecutive tokens in the alignment graph. This keeps the graph - unambiguous in the sense that the same input cannot transduce to - different outputs. - """ - super().__init__() - if blank not in ["optional", "forced", "none"]: - raise ValueError( - "Invalid value specified for blank. 
Must be in ['optional', 'forced', 'none']" - ) - self.tokens = make_token_graph(tokens, blank=blank, allow_repeats=allow_repeats) - self.lexicon = make_lexicon_graph(tokens, graphemes_to_idx) - self.ngram = ngram - if ngram > 0 and transitions is not None: - raise ValueError("Only one of ngram and transitions may be specified") - - if ngram > 0: - transitions = make_transitions_graph( - ngram, len(tokens) + int(blank != "none"), True - ) - - if transitions is not None: - self.transitions = transitions - self.transitions.arc_sort() - self.transitions_params = nn.Parameter( - torch.zeros(self.transitions.num_arcs()) - ) - else: - self.transitions = None - self.transitions_params = None - self.reduction = reduction - - def forward(self, inputs: Tensor, targets: Tensor) -> TransducerLoss: - TransducerLoss( - inputs, - targets, - self.tokens, - self.lexicon, - self.transitions_params, - self.transitions, - self.reduction, - ) - - def viterbi(self, outputs: Tensor) -> List[Tensor]: - B, T, C = outputs.shape - - if self.transitions is not None: - cpu_data = self.transition_params.cpu().contiguous() - self.transitions.set_weights(cpu_data.data_ptr()) - self.transitions.calc_grad = False - - self.tokens.arc_sort() - - paths = [None] * B - - def process(b: int) -> None: - emissions = gtn.linear_graph(T, C, False) - cpu_data = outputs[b].cpu().contiguous() - emissions.set_weights(cpu_data.data_ptr()) - - if self.transitions is not None: - full_graph = gtn.intersect(emissions, self.transitions) - else: - full_graph = emissions - - # Find the best path and remove back-off arcs: - path = gtn.remove(gtn.viterbi_path(full_graph)) - - # Left compose the viterbi path with the "aligment to token" - # transducer to get the outputs: - path = gtn.compose(path, self.tokens) - - # When there are ambiguous paths (allow_repeats is true), we take - # the shortest: - path = gtn.viterbi_path(path) - path = gtn.remove(gtn.project_output(path)) - paths[b] = path.labels_to_list() - - 
gtn.parallel_for(process, range(B)) - predictions = [torch.IntTensor(path) for path in paths] - return predictions - - -def load_transducer_loss( - num_features: int, - ngram: int, - tokens: str, - lexicon: str, - transitions: str, - blank: str, - allow_repeats: bool, - prepend_wordsep: bool = False, - use_words: bool = False, - data_dir: Optional[Union[str, Path]] = None, - reduction: str = "mean", -) -> Tuple[Transducer, int]: - if data_dir is None: - data_dir = ( - Path(__file__).resolve().parents[4] / "data" / "raw" / "iam" / "iamdb" - ) - logger.debug(f"Using data dir: {data_dir}") - if not data_dir.exists(): - raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") - else: - data_dir = Path(data_dir) - processed_path = ( - Path(__file__).resolve().parents[4] / "data" / "processed" / "iam_lines" - ) - tokens_path = processed_path / tokens - lexicon_path = processed_path / lexicon - - if transitions is not None: - transitions = gtn.load(str(processed_path / transitions)) - - preprocessor = Preprocessor( - data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep, - ) - - num_tokens = preprocessor.num_tokens - - criterion = Transducer( - preprocessor.tokens, - preprocessor.graphemes_to_index, - ngram=ngram, - transitions=transitions, - blank=blank, - allow_repeats=allow_repeats, - reduction=reduction, - ) - - return criterion, num_tokens + int(blank != "none") diff --git a/src/text_recognizer/networks/transformer/__init__.py b/src/text_recognizer/networks/transformer/__init__.py deleted file mode 100644 index 9febc88..0000000 --- a/src/text_recognizer/networks/transformer/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Transformer modules.""" -from .positional_encoding import PositionalEncoding -from .transformer import Decoder, Encoder, EncoderLayer, Transformer diff --git a/src/text_recognizer/networks/transformer/attention.py b/src/text_recognizer/networks/transformer/attention.py deleted file mode 100644 index cce1ecc..0000000 
--- a/src/text_recognizer/networks/transformer/attention.py +++ /dev/null @@ -1,93 +0,0 @@ -"""Implementes the attention module for the transformer.""" -from typing import Optional, Tuple - -from einops import rearrange -import numpy as np -import torch -from torch import nn -from torch import Tensor - - -class MultiHeadAttention(nn.Module): - """Implementation of multihead attention.""" - - def __init__( - self, hidden_dim: int, num_heads: int = 8, dropout_rate: float = 0.0 - ) -> None: - super().__init__() - self.hidden_dim = hidden_dim - self.num_heads = num_heads - self.fc_q = nn.Linear( - in_features=hidden_dim, out_features=hidden_dim, bias=False - ) - self.fc_k = nn.Linear( - in_features=hidden_dim, out_features=hidden_dim, bias=False - ) - self.fc_v = nn.Linear( - in_features=hidden_dim, out_features=hidden_dim, bias=False - ) - self.fc_out = nn.Linear(in_features=hidden_dim, out_features=hidden_dim) - - self._init_weights() - - self.dropout = nn.Dropout(p=dropout_rate) - - def _init_weights(self) -> None: - nn.init.normal_( - self.fc_q.weight, - mean=0, - std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)), - ) - nn.init.normal_( - self.fc_k.weight, - mean=0, - std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)), - ) - nn.init.normal_( - self.fc_v.weight, - mean=0, - std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)), - ) - nn.init.xavier_normal_(self.fc_out.weight) - - def scaled_dot_product_attention( - self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None - ) -> Tensor: - """Calculates the scaled dot product attention.""" - - # Compute the energy. - energy = torch.einsum("bhlk,bhtk->bhlt", [query, key]) / np.sqrt( - query.shape[-1] - ) - - # If we have a mask for padding some inputs. - if mask is not None: - energy = energy.masked_fill(mask == 0, -np.inf) - - # Compute the attention from the energy. 
- attention = torch.softmax(energy, dim=3) - - out = torch.einsum("bhlt,bhtv->bhlv", [attention, value]) - out = rearrange(out, "b head l v -> b l (head v)") - return out, attention - - def forward( - self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None - ) -> Tuple[Tensor, Tensor]: - """Forward pass for computing the multihead attention.""" - # Get the query, key, and value tensor. - query = rearrange( - self.fc_q(query), "b l (head k) -> b head l k", head=self.num_heads - ) - key = rearrange( - self.fc_k(key), "b t (head k) -> b head t k", head=self.num_heads - ) - value = rearrange( - self.fc_v(value), "b t (head v) -> b head t v", head=self.num_heads - ) - - out, attention = self.scaled_dot_product_attention(query, key, value, mask) - - out = self.fc_out(out) - out = self.dropout(out) - return out, attention diff --git a/src/text_recognizer/networks/transformer/positional_encoding.py b/src/text_recognizer/networks/transformer/positional_encoding.py deleted file mode 100644 index 1ba5537..0000000 --- a/src/text_recognizer/networks/transformer/positional_encoding.py +++ /dev/null @@ -1,32 +0,0 @@ -"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence.""" -import numpy as np -import torch -from torch import nn -from torch import Tensor - - -class PositionalEncoding(nn.Module): - """Encodes a sense of distance or time for transformer networks.""" - - def __init__( - self, hidden_dim: int, dropout_rate: float, max_len: int = 1000 - ) -> None: - super().__init__() - self.dropout = nn.Dropout(p=dropout_rate) - self.max_len = max_len - - pe = torch.zeros(max_len, hidden_dim) - position = torch.arange(0, max_len).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim) - ) - - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0) - self.register_buffer("pe", pe) - - def 
forward(self, x: Tensor) -> Tensor: - """Encodes the tensor with a postional embedding.""" - x = x + self.pe[:, : x.shape[1]] - return self.dropout(x) diff --git a/src/text_recognizer/networks/transformer/transformer.py b/src/text_recognizer/networks/transformer/transformer.py deleted file mode 100644 index dd180c4..0000000 --- a/src/text_recognizer/networks/transformer/transformer.py +++ /dev/null @@ -1,264 +0,0 @@ -"""Transfomer module.""" -import copy -from typing import Dict, Optional, Type, Union - -import numpy as np -import torch -from torch import nn -from torch import Tensor -import torch.nn.functional as F - -from text_recognizer.networks.transformer.attention import MultiHeadAttention -from text_recognizer.networks.util import activation_function - - -class GEGLU(nn.Module): - """GLU activation for improving feedforward activations.""" - - def __init__(self, dim_in: int, dim_out: int) -> None: - super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) - - def forward(self, x: Tensor) -> Tensor: - """Forward propagation.""" - x, gate = self.proj(x).chunk(2, dim=-1) - return x * F.gelu(gate) - - -def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList: - return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)]) - - -class _IntraLayerConnection(nn.Module): - """Preforms the residual connection inside the transfomer blocks and applies layernorm.""" - - def __init__(self, dropout_rate: float, hidden_dim: int) -> None: - super().__init__() - self.norm = nn.LayerNorm(normalized_shape=hidden_dim) - self.dropout = nn.Dropout(p=dropout_rate) - - def forward(self, src: Tensor, residual: Tensor) -> Tensor: - return self.norm(self.dropout(src) + residual) - - -class _ConvolutionalLayer(nn.Module): - def __init__( - self, - hidden_dim: int, - expansion_dim: int, - dropout_rate: float, - activation: str = "relu", - ) -> None: - super().__init__() - - in_projection = ( - nn.Sequential( - nn.Linear(hidden_dim, expansion_dim), 
activation_function(activation) - ) - if activation != "glu" - else GEGLU(hidden_dim, expansion_dim) - ) - - self.layer = nn.Sequential( - in_projection, - nn.Dropout(p=dropout_rate), - nn.Linear(in_features=expansion_dim, out_features=hidden_dim), - ) - - def forward(self, x: Tensor) -> Tensor: - return self.layer(x) - - -class EncoderLayer(nn.Module): - """Transfomer encoding layer.""" - - def __init__( - self, - hidden_dim: int, - num_heads: int, - expansion_dim: int, - dropout_rate: float, - activation: str = "relu", - ) -> None: - super().__init__() - self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate) - self.cnn = _ConvolutionalLayer( - hidden_dim, expansion_dim, dropout_rate, activation - ) - self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim) - self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim) - - def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor: - """Forward pass through the encoder.""" - # First block. - # Multi head attention. - out, _ = self.self_attention(src, src, src, mask) - - # Add & norm. - out = self.block1(out, src) - - # Second block. - # Apply 1D-convolution. - cnn_out = self.cnn(out) - - # Add & norm. 
- out = self.block2(cnn_out, out) - - return out - - -class Encoder(nn.Module): - """Transfomer encoder module.""" - - def __init__( - self, - num_layers: int, - encoder_layer: Type[nn.Module], - norm: Optional[Type[nn.Module]] = None, - ) -> None: - super().__init__() - self.layers = _get_clones(encoder_layer, num_layers) - self.norm = norm - - def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor: - """Forward pass through all encoder layers.""" - for layer in self.layers: - src = layer(src, src_mask) - - if self.norm is not None: - src = self.norm(src) - - return src - - -class DecoderLayer(nn.Module): - """Transfomer decoder layer.""" - - def __init__( - self, - hidden_dim: int, - num_heads: int, - expansion_dim: int, - dropout_rate: float = 0.0, - activation: str = "relu", - ) -> None: - super().__init__() - self.hidden_dim = hidden_dim - self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate) - self.multihead_attention = MultiHeadAttention( - hidden_dim, num_heads, dropout_rate - ) - self.cnn = _ConvolutionalLayer( - hidden_dim, expansion_dim, dropout_rate, activation - ) - self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim) - self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim) - self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim) - - def forward( - self, - trg: Tensor, - memory: Tensor, - trg_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - ) -> Tensor: - """Forward pass of the layer.""" - out, _ = self.self_attention(trg, trg, trg, trg_mask) - trg = self.block1(out, trg) - - out, _ = self.multihead_attention(trg, memory, memory, memory_mask) - trg = self.block2(out, trg) - - out = self.cnn(trg) - out = self.block3(out, trg) - - return out - - -class Decoder(nn.Module): - """Transfomer decoder module.""" - - def __init__( - self, - decoder_layer: Type[nn.Module], - num_layers: int, - norm: Optional[Type[nn.Module]] = None, - ) -> None: - super().__init__() - 
self.layers = _get_clones(decoder_layer, num_layers) - self.num_layers = num_layers - self.norm = norm - - def forward( - self, - trg: Tensor, - memory: Tensor, - trg_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - ) -> Tensor: - """Forward pass through the decoder.""" - for layer in self.layers: - trg = layer(trg, memory, trg_mask, memory_mask) - - if self.norm is not None: - trg = self.norm(trg) - - return trg - - -class Transformer(nn.Module): - """Transformer network.""" - - def __init__( - self, - num_encoder_layers: int, - num_decoder_layers: int, - hidden_dim: int, - num_heads: int, - expansion_dim: int, - dropout_rate: float, - activation: str = "relu", - ) -> None: - super().__init__() - - # Configure encoder. - encoder_norm = nn.LayerNorm(hidden_dim) - encoder_layer = EncoderLayer( - hidden_dim, num_heads, expansion_dim, dropout_rate, activation - ) - self.encoder = Encoder(num_encoder_layers, encoder_layer, encoder_norm) - - # Configure decoder. - decoder_norm = nn.LayerNorm(hidden_dim) - decoder_layer = DecoderLayer( - hidden_dim, num_heads, expansion_dim, dropout_rate, activation - ) - self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm) - - self._reset_parameters() - - def _reset_parameters(self) -> None: - for p in self.parameters(): - if p.dim() > 1: - nn.init.xavier_uniform_(p) - - def forward( - self, - src: Tensor, - trg: Tensor, - src_mask: Optional[Tensor] = None, - trg_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - ) -> Tensor: - """Forward pass through the transformer.""" - if src.shape[0] != trg.shape[0]: - print(trg.shape) - raise RuntimeError("The batch size of the src and trg must be the same.") - if src.shape[2] != trg.shape[2]: - raise RuntimeError( - "The number of features for the src and trg must be the same." 
- ) - - memory = self.encoder(src, src_mask) - output = self.decoder(trg, memory, trg_mask, memory_mask) - return output diff --git a/src/text_recognizer/networks/unet.py b/src/text_recognizer/networks/unet.py deleted file mode 100644 index 510910f..0000000 --- a/src/text_recognizer/networks/unet.py +++ /dev/null @@ -1,255 +0,0 @@ -"""UNet for segmentation.""" -from typing import List, Optional, Tuple, Union - -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.util import activation_function - - -class _ConvBlock(nn.Module): - """Modified UNet convolutional block with dilation.""" - - def __init__( - self, - channels: List[int], - activation: str, - num_groups: int, - dropout_rate: float = 0.1, - kernel_size: int = 3, - dilation: int = 1, - padding: int = 0, - ) -> None: - super().__init__() - self.channels = channels - self.dropout_rate = dropout_rate - self.kernel_size = kernel_size - self.dilation = dilation - self.padding = padding - self.num_groups = num_groups - self.activation = activation_function(activation) - self.block = self._configure_block() - self.residual_conv = nn.Sequential( - nn.Conv2d( - self.channels[0], self.channels[-1], kernel_size=3, stride=1, padding=1 - ), - self.activation, - ) - - def _configure_block(self) -> nn.Sequential: - block = [] - for i in range(len(self.channels) - 1): - block += [ - nn.Dropout(p=self.dropout_rate), - nn.GroupNorm(self.num_groups, self.channels[i]), - self.activation, - nn.Conv2d( - self.channels[i], - self.channels[i + 1], - kernel_size=self.kernel_size, - padding=self.padding, - stride=1, - dilation=self.dilation, - ), - ] - - return nn.Sequential(*block) - - def forward(self, x: Tensor) -> Tensor: - """Apply the convolutional block.""" - residual = self.residual_conv(x) - return self.block(x) + residual - - -class _DownSamplingBlock(nn.Module): - """Basic down sampling block.""" - - def __init__( - self, - channels: List[int], - activation: str, - num_groups: int, 
- pooling_kernel: Union[int, bool] = 2, - dropout_rate: float = 0.1, - kernel_size: int = 3, - dilation: int = 1, - padding: int = 0, - ) -> None: - super().__init__() - self.conv_block = _ConvBlock( - channels, - activation, - num_groups, - dropout_rate, - kernel_size, - dilation, - padding, - ) - self.down_sampling = nn.MaxPool2d(pooling_kernel) if pooling_kernel else None - - def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: - """Return the convolutional block output and a down sampled tensor.""" - x = self.conv_block(x) - x_down = self.down_sampling(x) if self.down_sampling is not None else x - - return x_down, x - - -class _UpSamplingBlock(nn.Module): - """The upsampling block of the UNet.""" - - def __init__( - self, - channels: List[int], - activation: str, - num_groups: int, - scale_factor: int = 2, - dropout_rate: float = 0.1, - kernel_size: int = 3, - dilation: int = 1, - padding: int = 0, - ) -> None: - super().__init__() - self.conv_block = _ConvBlock( - channels, - activation, - num_groups, - dropout_rate, - kernel_size, - dilation, - padding, - ) - self.up_sampling = nn.Upsample( - scale_factor=scale_factor, mode="bilinear", align_corners=True - ) - - def forward(self, x: Tensor, x_skip: Optional[Tensor] = None) -> Tensor: - """Apply the up sampling and convolutional block.""" - x = self.up_sampling(x) - if x_skip is not None: - x = torch.cat((x, x_skip), dim=1) - return self.conv_block(x) - - -class UNet(nn.Module): - """UNet architecture.""" - - def __init__( - self, - in_channels: int = 1, - base_channels: int = 64, - num_classes: int = 3, - depth: int = 4, - activation: str = "relu", - num_groups: int = 8, - dropout_rate: float = 0.1, - pooling_kernel: int = 2, - scale_factor: int = 2, - kernel_size: Optional[List[int]] = None, - dilation: Optional[List[int]] = None, - padding: Optional[List[int]] = None, - ) -> None: - super().__init__() - self.depth = depth - self.num_groups = num_groups - - if kernel_size is not None and dilation is not 
None and padding is not None: - if ( - len(kernel_size) != depth - and len(dilation) != depth - and len(padding) != depth - ): - raise RuntimeError( - "Length of convolutional parameters does not match the depth." - ) - self.kernel_size = kernel_size - self.padding = padding - self.dilation = dilation - - else: - self.kernel_size = [3] * depth - self.padding = [1] * depth - self.dilation = [1] * depth - - self.dropout_rate = dropout_rate - self.conv = nn.Conv2d( - in_channels, base_channels, kernel_size=3, stride=1, padding=1 - ) - - channels = [base_channels] + [base_channels * 2 ** i for i in range(depth)] - self.encoder_blocks = self._configure_down_sampling_blocks( - channels, activation, pooling_kernel - ) - self.decoder_blocks = self._configure_up_sampling_blocks( - channels, activation, scale_factor - ) - - self.head = nn.Conv2d(base_channels, num_classes, kernel_size=1) - - def _configure_down_sampling_blocks( - self, channels: List[int], activation: str, pooling_kernel: int - ) -> nn.ModuleList: - blocks = nn.ModuleList([]) - for i in range(len(channels) - 1): - pooling_kernel = pooling_kernel if i < self.depth - 1 else False - dropout_rate = self.dropout_rate if i < 0 else 0 - blocks += [ - _DownSamplingBlock( - [channels[i], channels[i + 1], channels[i + 1]], - activation, - self.num_groups, - pooling_kernel, - dropout_rate, - self.kernel_size[i], - self.dilation[i], - self.padding[i], - ) - ] - - return blocks - - def _configure_up_sampling_blocks( - self, channels: List[int], activation: str, scale_factor: int, - ) -> nn.ModuleList: - channels.reverse() - self.kernel_size.reverse() - self.dilation.reverse() - self.padding.reverse() - return nn.ModuleList( - [ - _UpSamplingBlock( - [channels[i] + channels[i + 1], channels[i + 1], channels[i + 1]], - activation, - self.num_groups, - scale_factor, - self.dropout_rate, - self.kernel_size[i], - self.dilation[i], - self.padding[i], - ) - for i in range(len(channels) - 2) - ] - ) - - def _encode(self, x: 
Tensor) -> List[Tensor]: - x_skips = [] - for block in self.encoder_blocks: - x, x_skip = block(x) - x_skips.append(x_skip) - return x_skips - - def _decode(self, x_skips: List[Tensor]) -> Tensor: - x = x_skips[-1] - for i, block in enumerate(self.decoder_blocks): - x = block(x, x_skips[-(i + 2)]) - return x - - def forward(self, x: Tensor) -> Tensor: - """Forward pass with the UNet model.""" - if len(x.shape) < 4: - x = x[(None,) * (4 - len(x.shape))] - x = self.conv(x) - x_skips = self._encode(x) - x = self._decode(x_skips) - return self.head(x) diff --git a/src/text_recognizer/networks/util.py b/src/text_recognizer/networks/util.py deleted file mode 100644 index 131a6b4..0000000 --- a/src/text_recognizer/networks/util.py +++ /dev/null @@ -1,89 +0,0 @@ -"""Miscellaneous neural network functionality.""" -import importlib -from pathlib import Path -from typing import Dict, Tuple, Type - -from einops import rearrange -from loguru import logger -import torch -from torch import nn - - -def sliding_window( - images: torch.Tensor, patch_size: Tuple[int, int], stride: Tuple[int, int] -) -> torch.Tensor: - """Creates patches of an image. - - Args: - images (torch.Tensor): A Torch tensor of a 4D image(s), i.e. (batch, channel, height, width). - patch_size (Tuple[int, int]): The size of the patches to generate, e.g. 28x28 for EMNIST. - stride (Tuple[int, int]): The stride of the sliding window. - - Returns: - torch.Tensor: A tensor with the shape (batch, patches, height, width). - - """ - unfold = nn.Unfold(kernel_size=patch_size, stride=stride) - # Preform the sliding window, unsqueeze as the channel dimesion is lost. 
- c = images.shape[1] - patches = unfold(images) - patches = rearrange( - patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1], - ) - return patches - - -def activation_function(activation: str) -> Type[nn.Module]: - """Returns the callable activation function.""" - activation_fns = nn.ModuleDict( - [ - ["elu", nn.ELU(inplace=True)], - ["gelu", nn.GELU()], - ["glu", nn.GLU()], - ["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)], - ["none", nn.Identity()], - ["relu", nn.ReLU(inplace=True)], - ["selu", nn.SELU(inplace=True)], - ] - ) - return activation_fns[activation.lower()] - - -def configure_backbone(backbone: str, backbone_args: Dict) -> Type[nn.Module]: - """Loads a backbone network.""" - network_module = importlib.import_module("text_recognizer.networks") - backbone_ = getattr(network_module, backbone) - - if "pretrained" in backbone_args: - logger.info("Loading pretrained backbone.") - checkpoint_file = Path(__file__).resolve().parents[2] / backbone_args.pop( - "pretrained" - ) - - # Loading state directory. - state_dict = torch.load(checkpoint_file) - network_args = state_dict["network_args"] - weights = state_dict["model_state"] - - freeze = False - if "freeze" in backbone_args and backbone_args["freeze"] is True: - backbone_args.pop("freeze") - freeze = True - network_args = backbone_args - - # Initializes the network with trained weights. 
- backbone = backbone_(**network_args) - backbone.load_state_dict(weights) - if freeze: - for params in backbone.parameters(): - params.requires_grad = False - else: - backbone_ = getattr(network_module, backbone) - backbone = backbone_(**backbone_args) - - if "remove_layers" in backbone_args and backbone_args["remove_layers"] is not None: - backbone = nn.Sequential( - *list(backbone.children())[:][: -backbone_args["remove_layers"]] - ) - - return backbone diff --git a/src/text_recognizer/networks/vit.py b/src/text_recognizer/networks/vit.py deleted file mode 100644 index efb3701..0000000 --- a/src/text_recognizer/networks/vit.py +++ /dev/null @@ -1,150 +0,0 @@ -"""A Vision Transformer. - -Inspired by: -https://openreview.net/pdf?id=YicbFdNTTy - -""" -from typing import Optional, Tuple - -from einops import rearrange, repeat -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.transformer import Transformer - - -class ViT(nn.Module): - """Transfomer for image to sequence prediction.""" - - def __init__( - self, - num_encoder_layers: int, - num_decoder_layers: int, - hidden_dim: int, - vocab_size: int, - num_heads: int, - expansion_dim: int, - patch_dim: Tuple[int, int], - image_size: Tuple[int, int], - dropout_rate: float, - trg_pad_index: int, - max_len: int, - activation: str = "gelu", - ) -> None: - super().__init__() - - self.trg_pad_index = trg_pad_index - self.patch_dim = patch_dim - self.num_patches = image_size[-1] // self.patch_dim[1] - - # Encoder - self.patch_to_embedding = nn.Linear( - self.patch_dim[0] * self.patch_dim[1], hidden_dim - ) - self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim)) - self.character_embedding = nn.Embedding(vocab_size, hidden_dim) - self.pos_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim)) - self.dropout = nn.Dropout(dropout_rate) - self._init() - - self.transformer = Transformer( - num_encoder_layers, - num_decoder_layers, - hidden_dim, - num_heads, - 
expansion_dim, - dropout_rate, - activation, - ) - - self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),) - - def _init(self) -> None: - nn.init.normal_(self.character_embedding.weight, std=0.02) - # nn.init.normal_(self.pos_embedding.weight, std=0.02) - - def _create_trg_mask(self, trg: Tensor) -> Tensor: - # Move this outside the transformer. - trg_pad_mask = (trg != self.trg_pad_index)[:, None, None] - trg_len = trg.shape[1] - trg_sub_mask = torch.tril( - torch.ones((trg_len, trg_len), device=trg.device) - ).bool() - trg_mask = trg_pad_mask & trg_sub_mask - return trg_mask - - def encoder(self, src: Tensor) -> Tensor: - """Forward pass with the encoder of the transformer.""" - return self.transformer.encoder(src) - - def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor: - """Forward pass with the decoder of the transformer + classification head.""" - return self.head( - self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask) - ) - - def extract_image_features(self, src: Tensor) -> Tensor: - """Extracts image features with a backbone neural network. - - It seem like the winning idea was to swap channels and width dimension and collapse - the height dimension. The transformer is learning like a baby with this implementation!!! :D - Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D - - Args: - src (Tensor): Input tensor. - - Returns: - Tensor: A input src to the transformer. - - """ - # If batch dimension is missing, it needs to be added. - if len(src.shape) < 4: - src = src[(None,) * (4 - len(src.shape))] - - patches = rearrange( - src, - "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", - p1=self.patch_dim[0], - p2=self.patch_dim[1], - ) - - # From patches to encoded sequence. 
- x = self.patch_to_embedding(patches) - b, n, _ = x.shape - cls_tokens = repeat(self.cls_token, "() n d -> b n d", b=b) - x = torch.cat((cls_tokens, x), dim=1) - x += self.pos_embedding[:, : (n + 1)] - x = self.dropout(x) - - return x - - def target_embedding(self, trg: Tensor) -> Tuple[Tensor, Tensor]: - """Encodes target tensor with embedding and postion. - - Args: - trg (Tensor): Target tensor. - - Returns: - Tuple[Tensor, Tensor]: Encoded target tensor and target mask. - - """ - _, n = trg.shape - trg = self.character_embedding(trg.long()) - trg += self.pos_embedding[:, :n] - return trg - - def decode_image_features(self, h: Tensor, trg: Optional[Tensor] = None) -> Tensor: - """Takes images features from the backbone and decodes them with the transformer.""" - trg_mask = self._create_trg_mask(trg) - trg = self.target_embedding(trg) - out = self.transformer(h, trg, trg_mask=trg_mask) - - logits = self.head(out) - return logits - - def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor: - """Forward pass with CNN transfomer.""" - h = self.extract_image_features(x) - logits = self.decode_image_features(h, trg) - return logits diff --git a/src/text_recognizer/networks/vq_transformer.py b/src/text_recognizer/networks/vq_transformer.py deleted file mode 100644 index c673d96..0000000 --- a/src/text_recognizer/networks/vq_transformer.py +++ /dev/null @@ -1,150 +0,0 @@ -"""A VQ-Transformer for image to text recognition.""" -from typing import Dict, Optional, Tuple - -from einops import rearrange, repeat -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.transformer import PositionalEncoding, Transformer -from text_recognizer.networks.util import activation_function -from text_recognizer.networks.util import configure_backbone -from text_recognizer.networks.vqvae.encoder import _ResidualBlock - - -class VQTransformer(nn.Module): - """VQ+Transfomer for image to character sequence prediction.""" - - def __init__( - 
self, - num_encoder_layers: int, - num_decoder_layers: int, - hidden_dim: int, - vocab_size: int, - num_heads: int, - adaptive_pool_dim: Tuple, - expansion_dim: int, - dropout_rate: float, - trg_pad_index: int, - max_len: int, - backbone: str, - backbone_args: Optional[Dict] = None, - activation: str = "gelu", - ) -> None: - super().__init__() - - # Configure vector quantized backbone. - self.backbone = configure_backbone(backbone, backbone_args) - self.conv = nn.Sequential( - nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2), - nn.ReLU(inplace=True), - ) - - # Configure embeddings for Transformer network. - self.trg_pad_index = trg_pad_index - self.vocab_size = vocab_size - self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim) - self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim)) - self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate) - nn.init.normal_(self.character_embedding.weight, std=0.02) - - self.adaptive_pool = ( - nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None - ) - - self.transformer = Transformer( - num_encoder_layers, - num_decoder_layers, - hidden_dim, - num_heads, - expansion_dim, - dropout_rate, - activation, - ) - - self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),) - - def _create_trg_mask(self, trg: Tensor) -> Tensor: - # Move this outside the transformer. 
- trg_pad_mask = (trg != self.trg_pad_index)[:, None, None] - trg_len = trg.shape[1] - trg_sub_mask = torch.tril( - torch.ones((trg_len, trg_len), device=trg.device) - ).bool() - trg_mask = trg_pad_mask & trg_sub_mask - return trg_mask - - def encoder(self, src: Tensor) -> Tensor: - """Forward pass with the encoder of the transformer.""" - return self.transformer.encoder(src) - - def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor: - """Forward pass with the decoder of the transformer + classification head.""" - return self.head( - self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask) - ) - - def extract_image_features(self, src: Tensor) -> Tuple[Tensor, Tensor]: - """Extracts image features with a backbone neural network. - - It seem like the winning idea was to swap channels and width dimension and collapse - the height dimension. The transformer is learning like a baby with this implementation!!! :D - Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D - - Args: - src (Tensor): Input tensor. - - Returns: - Tensor: The input src to the transformer and the vq loss. - - """ - # If batch dimension is missing, it needs to be added. - if len(src.shape) < 4: - src = src[(None,) * (4 - len(src.shape))] - src, vq_loss = self.backbone.encode(src) - # src = self.backbone.decoder.res_block(src) - src = self.conv(src) - - if self.adaptive_pool is not None: - src = rearrange(src, "b c h w -> b w c h") - src = self.adaptive_pool(src) - src = src.squeeze(3) - else: - src = rearrange(src, "b c h w -> b (w h) c") - - b, t, _ = src.shape - - src += self.src_position_embedding[:, :t] - - return src, vq_loss - - def target_embedding(self, trg: Tensor) -> Tensor: - """Encodes target tensor with embedding and postion. - - Args: - trg (Tensor): Target tensor. - - Returns: - Tensor: Encoded target tensor. 
- - """ - trg = self.character_embedding(trg.long()) - trg = self.trg_position_encoding(trg) - return trg - - def decode_image_features( - self, image_features: Tensor, trg: Optional[Tensor] = None - ) -> Tensor: - """Takes images features from the backbone and decodes them with the transformer.""" - trg_mask = self._create_trg_mask(trg) - trg = self.target_embedding(trg) - out = self.transformer(image_features, trg, trg_mask=trg_mask) - - logits = self.head(out) - return logits - - def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor: - """Forward pass with CNN transfomer.""" - image_features, vq_loss = self.extract_image_features(x) - logits = self.decode_image_features(image_features, trg) - return logits, vq_loss diff --git a/src/text_recognizer/networks/vqvae/__init__.py b/src/text_recognizer/networks/vqvae/__init__.py deleted file mode 100644 index 763953c..0000000 --- a/src/text_recognizer/networks/vqvae/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""VQ-VAE module.""" -from .decoder import Decoder -from .encoder import Encoder -from .vector_quantizer import VectorQuantizer -from .vqvae import VQVAE diff --git a/src/text_recognizer/networks/vqvae/decoder.py b/src/text_recognizer/networks/vqvae/decoder.py deleted file mode 100644 index 8847aba..0000000 --- a/src/text_recognizer/networks/vqvae/decoder.py +++ /dev/null @@ -1,133 +0,0 @@ -"""CNN decoder for the VQ-VAE.""" - -from typing import List, Optional, Tuple, Type - -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.util import activation_function -from text_recognizer.networks.vqvae.encoder import _ResidualBlock - - -class Decoder(nn.Module): - """A CNN encoder network.""" - - def __init__( - self, - channels: List[int], - kernel_sizes: List[int], - strides: List[int], - num_residual_layers: int, - embedding_dim: int, - upsampling: Optional[List[List[int]]] = None, - activation: str = "leaky_relu", - dropout_rate: float = 0.0, - ) -> None: - 
super().__init__() - - if dropout_rate: - if activation == "selu": - dropout = nn.AlphaDropout(p=dropout_rate) - else: - dropout = nn.Dropout(p=dropout_rate) - else: - dropout = None - - self.upsampling = upsampling - - self.res_block = nn.ModuleList([]) - self.upsampling_block = nn.ModuleList([]) - - self.embedding_dim = embedding_dim - activation = activation_function(activation) - - # Configure encoder. - self.decoder = self._build_decoder( - channels, kernel_sizes, strides, num_residual_layers, activation, dropout, - ) - - def _build_decompression_block( - self, - in_channels: int, - channels: int, - kernel_sizes: List[int], - strides: List[int], - activation: Type[nn.Module], - dropout: Optional[Type[nn.Module]], - ) -> nn.ModuleList: - modules = nn.ModuleList([]) - configuration = zip(channels, kernel_sizes, strides) - for i, (out_channels, kernel_size, stride) in enumerate(configuration): - modules.append( - nn.Sequential( - nn.ConvTranspose2d( - in_channels, - out_channels, - kernel_size, - stride=stride, - padding=1, - ), - activation, - ) - ) - - if i < len(self.upsampling): - modules.append(nn.Upsample(size=self.upsampling[i]),) - - if dropout is not None: - modules.append(dropout) - - in_channels = out_channels - - modules.extend( - nn.Sequential( - nn.ConvTranspose2d( - in_channels, 1, kernel_size=kernel_size, stride=stride, padding=1 - ), - nn.Tanh(), - ) - ) - - return modules - - def _build_decoder( - self, - channels: int, - kernel_sizes: List[int], - strides: List[int], - num_residual_layers: int, - activation: Type[nn.Module], - dropout: Optional[Type[nn.Module]], - ) -> nn.Sequential: - - self.res_block.append( - nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,) - ) - - # Bottleneck module. 
- self.res_block.extend( - nn.ModuleList( - [ - _ResidualBlock(channels[0], channels[0], dropout) - for i in range(num_residual_layers) - ] - ) - ) - - # Decompression module - self.upsampling_block.extend( - self._build_decompression_block( - channels[0], channels[1:], kernel_sizes, strides, activation, dropout - ) - ) - - self.res_block = nn.Sequential(*self.res_block) - self.upsampling_block = nn.Sequential(*self.upsampling_block) - - return nn.Sequential(self.res_block, self.upsampling_block) - - def forward(self, z_q: Tensor) -> Tensor: - """Reconstruct input from given codes.""" - x_reconstruction = self.decoder(z_q) - return x_reconstruction diff --git a/src/text_recognizer/networks/vqvae/encoder.py b/src/text_recognizer/networks/vqvae/encoder.py deleted file mode 100644 index d3adac5..0000000 --- a/src/text_recognizer/networks/vqvae/encoder.py +++ /dev/null @@ -1,147 +0,0 @@ -"""CNN encoder for the VQ-VAE.""" -from typing import List, Optional, Tuple, Type - -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.util import activation_function -from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer - - -class _ResidualBlock(nn.Module): - def __init__( - self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]], - ) -> None: - super().__init__() - self.block = [ - nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), - nn.ReLU(inplace=True), - nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False), - ] - - if dropout is not None: - self.block.append(dropout) - - self.block = nn.Sequential(*self.block) - - def forward(self, x: Tensor) -> Tensor: - """Apply the residual forward pass.""" - return x + self.block(x) - - -class Encoder(nn.Module): - """A CNN encoder network.""" - - def __init__( - self, - in_channels: int, - channels: List[int], - kernel_sizes: List[int], - strides: List[int], - num_residual_layers: int, - embedding_dim: int, - 
num_embeddings: int, - beta: float = 0.25, - activation: str = "leaky_relu", - dropout_rate: float = 0.0, - ) -> None: - super().__init__() - - if dropout_rate: - if activation == "selu": - dropout = nn.AlphaDropout(p=dropout_rate) - else: - dropout = nn.Dropout(p=dropout_rate) - else: - dropout = None - - self.embedding_dim = embedding_dim - self.num_embeddings = num_embeddings - self.beta = beta - activation = activation_function(activation) - - # Configure encoder. - self.encoder = self._build_encoder( - in_channels, - channels, - kernel_sizes, - strides, - num_residual_layers, - activation, - dropout, - ) - - # Configure Vector Quantizer. - self.vector_quantizer = VectorQuantizer( - self.num_embeddings, self.embedding_dim, self.beta - ) - - def _build_compression_block( - self, - in_channels: int, - channels: int, - kernel_sizes: List[int], - strides: List[int], - activation: Type[nn.Module], - dropout: Optional[Type[nn.Module]], - ) -> nn.ModuleList: - modules = nn.ModuleList([]) - configuration = zip(channels, kernel_sizes, strides) - for out_channels, kernel_size, stride in configuration: - modules.append( - nn.Sequential( - nn.Conv2d( - in_channels, out_channels, kernel_size, stride=stride, padding=1 - ), - activation, - ) - ) - - if dropout is not None: - modules.append(dropout) - - in_channels = out_channels - - return modules - - def _build_encoder( - self, - in_channels: int, - channels: int, - kernel_sizes: List[int], - strides: List[int], - num_residual_layers: int, - activation: Type[nn.Module], - dropout: Optional[Type[nn.Module]], - ) -> nn.Sequential: - encoder = nn.ModuleList([]) - - # compression module - encoder.extend( - self._build_compression_block( - in_channels, channels, kernel_sizes, strides, activation, dropout - ) - ) - - # Bottleneck module. 
- encoder.extend( - nn.ModuleList( - [ - _ResidualBlock(channels[-1], channels[-1], dropout) - for i in range(num_residual_layers) - ] - ) - ) - - encoder.append( - nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,) - ) - - return nn.Sequential(*encoder) - - def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: - """Encodes input into a discrete representation.""" - z_e = self.encoder(x) - z_q, vq_loss = self.vector_quantizer(z_e) - return z_q, vq_loss diff --git a/src/text_recognizer/networks/vqvae/vector_quantizer.py b/src/text_recognizer/networks/vqvae/vector_quantizer.py deleted file mode 100644 index f92c7ee..0000000 --- a/src/text_recognizer/networks/vqvae/vector_quantizer.py +++ /dev/null @@ -1,119 +0,0 @@ -"""Implementation of a Vector Quantized Variational AutoEncoder. - -Reference: -https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py - -""" - -from einops import rearrange -import torch -from torch import nn -from torch import Tensor -from torch.nn import functional as F - - -class VectorQuantizer(nn.Module): - """The codebook that contains quantized vectors.""" - - def __init__( - self, num_embeddings: int, embedding_dim: int, beta: float = 0.25 - ) -> None: - super().__init__() - self.K = num_embeddings - self.D = embedding_dim - self.beta = beta - - self.embedding = nn.Embedding(self.K, self.D) - - # Initialize the codebook. - nn.init.uniform_(self.embedding.weight, -1 / self.K, 1 / self.K) - - def discretization_bottleneck(self, latent: Tensor) -> Tensor: - """Computes the code nearest to the latent representation. - - First we compute the posterior categorical distribution, and then map - the latent representation to the nearest element of the embedding. - - Args: - latent (Tensor): The latent representation. - - Shape: - - latent :math:`(B x H x W, D)` - - Returns: - Tensor: The quantized embedding vector. - - """ - # Store latent shape. - b, h, w, d = latent.shape - - # Flatten the latent representation to 2D. 
- latent = rearrange(latent, "b h w d -> (b h w) d") - - # Compute the L2 distance between the latents and the embeddings. - l2_distance = ( - torch.sum(latent ** 2, dim=1, keepdim=True) - + torch.sum(self.embedding.weight ** 2, dim=1) - - 2 * latent @ self.embedding.weight.t() - ) # [BHW x K] - - # Find the embedding k nearest to each latent. - encoding_indices = torch.argmin(l2_distance, dim=1).unsqueeze(1) # [BHW, 1] - - # Convert to one-hot encodings, aka discrete bottleneck. - one_hot_encoding = torch.zeros( - encoding_indices.shape[0], self.K, device=latent.device - ) - one_hot_encoding.scatter_(1, encoding_indices, 1) # [BHW x K] - - # Embedding quantization. - quantized_latent = one_hot_encoding @ self.embedding.weight # [BHW, D] - quantized_latent = rearrange( - quantized_latent, "(b h w) d -> b h w d", b=b, h=h, w=w - ) - - return quantized_latent - - def vq_loss(self, latent: Tensor, quantized_latent: Tensor) -> Tensor: - """Vector Quantization loss. - - The vector quantization algorithm allows us to create a codebook. The VQ - algorithm works by moving the embedding vectors towards the encoder outputs. - - The embedding loss moves the embedding vector towards the encoder outputs. The - .detach() works as the stop gradient (sg) described in the paper. - - Because the volume of the embedding space is dimensionless, it can arbitarily - grow if the embeddings are not trained as fast as the encoder parameters. To - mitigate this, a commitment loss is added in the second term which makes sure - that the encoder commits to an embedding and that its output does not grow. - - Args: - latent (Tensor): The encoder output. - quantized_latent (Tensor): The quantized latent. - - Returns: - Tensor: The combinded VQ loss. 
- - """ - embedding_loss = F.mse_loss(quantized_latent, latent.detach()) - commitment_loss = F.mse_loss(quantized_latent.detach(), latent) - return embedding_loss + self.beta * commitment_loss - - def forward(self, latent: Tensor) -> Tensor: - """Forward pass that returns the quantized vector and the vq loss.""" - # Rearrange latent representation s.t. the hidden dim is at the end. - latent = rearrange(latent, "b d h w -> b h w d") - - # Maps latent to the nearest code in the codebook. - quantized_latent = self.discretization_bottleneck(latent) - - loss = self.vq_loss(latent, quantized_latent) - - # Add residue to the quantized latent. - quantized_latent = latent + (quantized_latent - latent).detach() - - # Rearrange the quantized shape back to the original shape. - quantized_latent = rearrange(quantized_latent, "b h w d -> b d h w") - - return quantized_latent, loss diff --git a/src/text_recognizer/networks/vqvae/vqvae.py b/src/text_recognizer/networks/vqvae/vqvae.py deleted file mode 100644 index 50448b4..0000000 --- a/src/text_recognizer/networks/vqvae/vqvae.py +++ /dev/null @@ -1,74 +0,0 @@ -"""The VQ-VAE.""" - -from typing import List, Optional, Tuple, Type - -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.vqvae import Decoder, Encoder - - -class VQVAE(nn.Module): - """Vector Quantized Variational AutoEncoder.""" - - def __init__( - self, - in_channels: int, - channels: List[int], - kernel_sizes: List[int], - strides: List[int], - num_residual_layers: int, - embedding_dim: int, - num_embeddings: int, - upsampling: Optional[List[List[int]]] = None, - beta: float = 0.25, - activation: str = "leaky_relu", - dropout_rate: float = 0.0, - ) -> None: - super().__init__() - - # configure encoder. - self.encoder = Encoder( - in_channels, - channels, - kernel_sizes, - strides, - num_residual_layers, - embedding_dim, - num_embeddings, - beta, - activation, - dropout_rate, - ) - - # Configure decoder. 
- channels.reverse() - kernel_sizes.reverse() - strides.reverse() - self.decoder = Decoder( - channels, - kernel_sizes, - strides, - num_residual_layers, - embedding_dim, - upsampling, - activation, - dropout_rate, - ) - - def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]: - """Encodes input to a latent code.""" - return self.encoder(x) - - def decode(self, z_q: Tensor) -> Tensor: - """Reconstructs input from latent codes.""" - return self.decoder(z_q) - - def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: - """Compresses and decompresses input.""" - if len(x.shape) < 4: - x = x[(None,) * (4 - len(x.shape))] - z_q, vq_loss = self.encode(x) - x_reconstruction = self.decode(z_q) - return x_reconstruction, vq_loss diff --git a/src/text_recognizer/networks/wide_resnet.py b/src/text_recognizer/networks/wide_resnet.py deleted file mode 100644 index b767778..0000000 --- a/src/text_recognizer/networks/wide_resnet.py +++ /dev/null @@ -1,221 +0,0 @@ -"""Wide Residual CNN.""" -from functools import partial -from typing import Callable, Dict, List, Optional, Type, Union - -from einops.layers.torch import Reduce -import numpy as np -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.util import activation_function - - -def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: - """Helper function for a 3x3 2d convolution.""" - return nn.Conv2d( - in_channels=in_planes, - out_channels=out_planes, - kernel_size=3, - stride=stride, - padding=1, - bias=False, - ) - - -def conv_init(module: Type[nn.Module]) -> None: - """Initializes the weights for convolution and batchnorms.""" - classname = module.__class__.__name__ - if classname.find("Conv") != -1: - nn.init.xavier_uniform_(module.weight, gain=np.sqrt(2)) - nn.init.constant_(module.bias, 0) - elif classname.find("BatchNorm") != -1: - nn.init.constant_(module.weight, 1) - nn.init.constant_(module.bias, 0) - - -class WideBlock(nn.Module): - """Block used in 
WideResNet.""" - - def __init__( - self, - in_planes: int, - out_planes: int, - dropout_rate: float, - stride: int = 1, - activation: str = "relu", - ) -> None: - super().__init__() - self.in_planes = in_planes - self.out_planes = out_planes - self.dropout_rate = dropout_rate - self.stride = stride - self.activation = activation_function(activation) - - # Build blocks. - self.blocks = nn.Sequential( - nn.BatchNorm2d(self.in_planes), - self.activation, - conv3x3(in_planes=self.in_planes, out_planes=self.out_planes), - nn.Dropout(p=self.dropout_rate), - nn.BatchNorm2d(self.out_planes), - self.activation, - conv3x3( - in_planes=self.out_planes, - out_planes=self.out_planes, - stride=self.stride, - ), - ) - - self.shortcut = ( - nn.Sequential( - nn.Conv2d( - in_channels=self.in_planes, - out_channels=self.out_planes, - kernel_size=1, - stride=self.stride, - bias=False, - ), - ) - if self._apply_shortcut - else None - ) - - @property - def _apply_shortcut(self) -> bool: - """If shortcut should be applied or not.""" - return self.stride != 1 or self.in_planes != self.out_planes - - def forward(self, x: Tensor) -> Tensor: - """Forward pass.""" - residual = x - if self._apply_shortcut: - residual = self.shortcut(x) - x = self.blocks(x) - x += residual - return x - - -class WideResidualNetwork(nn.Module): - """WideResNet for character predictions. - - Can be used for classification or encoding of images to a latent vector. - - """ - - def __init__( - self, - in_channels: int = 1, - in_planes: int = 16, - num_classes: int = 80, - depth: int = 16, - width_factor: int = 10, - dropout_rate: float = 0.0, - num_layers: int = 3, - block: Type[nn.Module] = WideBlock, - num_stages: Optional[List[int]] = None, - activation: str = "relu", - use_decoder: bool = True, - ) -> None: - """The initialization of the WideResNet. - - Args: - in_channels (int): Number of input channels. Defaults to 1. - in_planes (int): Number of channels to use in the first output kernel. Defaults to 16. 
- num_classes (int): Number of classes. Defaults to 80. - depth (int): Set the number of blocks to use. Defaults to 16. - width_factor (int): Factor for scaling the number of channels in the network. Defaults to 10. - dropout_rate (float): The dropout rate. Defaults to 0.0. - num_layers (int): Number of layers of blocks. Defaults to 3. - block (Type[nn.Module]): The default block is WideBlock. Defaults to WideBlock. - num_stages (List[int]): If given, will use these channel values. Defaults to None. - activation (str): Name of the activation to use. Defaults to "relu". - use_decoder (bool): If True, the network output character predictions, if False, the network outputs a - latent vector. Defaults to True. - - Raises: - RuntimeError: If the depth is not of the size `6n+4`. - - """ - - super().__init__() - if (depth - 4) % 6 != 0: - raise RuntimeError("Wide-resnet depth should be 6n+4") - self.in_channels = in_channels - self.in_planes = in_planes - self.num_classes = num_classes - self.num_blocks = (depth - 4) // 6 - self.width_factor = width_factor - self.num_layers = num_layers - self.block = block - self.dropout_rate = dropout_rate - self.activation = activation_function(activation) - - if num_stages is None: - self.num_stages = [self.in_planes] + [ - self.in_planes * 2 ** n * self.width_factor - for n in range(self.num_layers) - ] - else: - self.num_stages = [self.in_planes] + num_stages - - self.num_stages = list(zip(self.num_stages, self.num_stages[1:])) - self.strides = [1] + [2] * (self.num_layers - 1) - - self.encoder = nn.Sequential( - conv3x3(in_planes=self.in_channels, out_planes=self.in_planes), - *[ - self._configure_wide_layer( - in_planes=in_planes, - out_planes=out_planes, - stride=stride, - activation=activation, - ) - for (in_planes, out_planes), stride in zip( - self.num_stages, self.strides - ) - ], - ) - - self.decoder = ( - nn.Sequential( - nn.BatchNorm2d(self.num_stages[-1][-1], momentum=0.8), - self.activation, - Reduce("b c h w -> b c", 
"mean"), - nn.Linear( - in_features=self.num_stages[-1][-1], out_features=self.num_classes - ), - ) - if use_decoder - else None - ) - - # self.apply(conv_init) - - def _configure_wide_layer( - self, in_planes: int, out_planes: int, stride: int, activation: str - ) -> List: - strides = [stride] + [1] * (self.num_blocks - 1) - planes = [out_planes] * len(strides) - planes = [(in_planes, out_planes)] + list(zip(planes, planes[1:])) - return nn.Sequential( - *[ - self.block( - in_planes=in_planes, - out_planes=out_planes, - dropout_rate=self.dropout_rate, - stride=stride, - activation=activation, - ) - for (in_planes, out_planes), stride in zip(planes, strides) - ] - ) - - def forward(self, x: Tensor) -> Tensor: - """Feedforward pass.""" - if len(x.shape) < 4: - x = x[(None,) * int(4 - len(x.shape))] - x = self.encoder(x) - if self.decoder is not None: - x = self.decoder(x) - return x diff --git a/src/text_recognizer/paragraph_text_recognizer.py b/src/text_recognizer/paragraph_text_recognizer.py deleted file mode 100644 index aa39662..0000000 --- a/src/text_recognizer/paragraph_text_recognizer.py +++ /dev/null @@ -1,153 +0,0 @@ -"""Full model. - -Takes an image and returns the text in the image, by first segmenting the image with a LineDetector, then extracting the -each crop of the image corresponding to line regions, and feeding them to a LinePredictor model that outputs the text -in each region. 
-""" -from typing import Dict, List, Tuple, Union - -import cv2 -import numpy as np -import torch - -from text_recognizer.models import SegmentationModel, TransformerModel -from text_recognizer.util import read_image - - -class ParagraphTextRecognizor: - """Given an image of a single handwritten character, recognizes it.""" - - def __init__(self, line_predictor_args: Dict, line_detector_args: Dict) -> None: - self._line_predictor = TransformerModel(**line_predictor_args) - self._line_detector = SegmentationModel(**line_detector_args) - self._line_detector.eval() - self._line_predictor.eval() - - def predict(self, image_or_filename: Union[str, np.ndarray]) -> Tuple: - """Takes an image and returns all text within it.""" - image = ( - read_image(image_or_filename) - if isinstance(image_or_filename, str) - else image_or_filename - ) - - line_region_crops = self._get_line_region_crops(image) - processed_line_region_crops = [ - self._process_image_for_line_predictor(image=crop) - for crop in line_region_crops - ] - line_region_strings = [ - self.line_predictor_model.predict_on_image(crop)[0] - for crop in processed_line_region_crops - ] - - return " ".join(line_region_strings), line_region_crops - - def _get_line_region_crops( - self, image: np.ndarray, min_crop_len_factor: float = 0.02 - ) -> List[np.ndarray]: - """Returns all the crops of text lines in a square image.""" - processed_image, scale_down_factor = self._process_image_for_line_detector( - image - ) - line_segmentation = self._line_detector.predict_on_image(processed_image) - bounding_boxes = _find_line_bounding_boxes(line_segmentation) - - bounding_boxes = (bounding_boxes * scale_down_factor).astype(int) - - min_crop_len = int(min_crop_len_factor * min(image.shape[0], image.shape[1])) - line_region_crops = [ - image[y : y + h, x : x + w] - for x, y, w, h in bounding_boxes - if w >= min_crop_len and h >= min_crop_len - ] - return line_region_crops - - def _process_image_for_line_detector( - self, image: 
np.ndarray - ) -> Tuple[np.ndarray, float]: - """Convert uint8 image to float image with black background with shape self._line_detector.image_shape.""" - resized_image, scale_down_factor = _resize_image_for_line_detector( - image=image, max_shape=self._line_detector.image_shape - ) - resized_image = (1.0 - resized_image / 255).astype("float32") - return resized_image, scale_down_factor - - def _process_image_for_line_predictor(self, image: np.ndarray) -> np.ndarray: - """Preprocessing of image before feeding it to the LinePrediction model. - - Convert uint8 image to float image with black background with shape - self._line_predictor.image_shape while maintaining the image aspect ratio. - - Args: - image (np.ndarray): Crop of text line. - - Returns: - np.ndarray: Processed crop for feeding line predictor. - """ - expected_shape = self._line_detector.image_shape - scale_factor = (np.array(expected_shape) / np.array(image.shape)).min() - scaled_image = cv2.resize( - image, - dsize=None, - fx=scale_factor, - fy=scale_factor, - interpolation=cv2.INTER_AREA, - ) - - pad_with = ( - (0, expected_shape[0] - scaled_image.shape[0]), - (0, expected_shape[1] - scaled_image.shape[1]), - ) - - padded_image = np.pad( - scaled_image, pad_with=pad_with, mode="constant", constant_values=255 - ) - return 1 - padded_image / 255 - - -def _find_line_bounding_boxes(line_segmentation: np.ndarray) -> np.ndarray: - """Given a line segmentation, find bounding boxes for connected-component regions corresponding to non-0 labels.""" - - def _find_line_bounding_boxes_in_channel( - line_segmentation_channel: np.ndarray, - ) -> np.ndarray: - line_segmentation_image = cv2.dilate( - line_segmentation_channel, kernel=np.ones((3, 3)), iterations=1 - ) - line_activation_image = (line_segmentation_image * 255).astype("uint8") - line_activation_image = cv2.threshold( - line_activation_image, 0.5, 1, cv2.THRESH_BINARY | cv2.THRESH_OTSU - )[1] - - bounding_cnts, _ = cv2.findContours( - 
line_segmentation_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE - ) - return np.array([cv2.boundingRect(cnt) for cnt in bounding_cnts]) - - bounding_boxes = np.concatenate( - [ - _find_line_bounding_boxes_in_channel(line_segmentation[:, :, i]) - for i in [1, 2] - ], - axis=0, - ) - - return bounding_boxes[np.argsort(bounding_boxes[:, 1])] - - -def _resize_image_for_line_detector( - image: np.ndarray, max_shape: Tuple[int, int] -) -> Tuple[np.ndarray, float]: - """Resize the image to less than the max_shape while maintaining the aspect ratio.""" - scale_down_factor = max(np.ndarray(image.shape) / np.ndarray(max_shape)) - if scale_down_factor == 1: - return image.copy(), scale_down_factor - resize_image = cv2.resize( - image, - dsize=None, - fx=1 / scale_down_factor, - fy=1 / scale_down_factor, - interpolation=cv2.INTER_AREA, - ) - return resize_image, scale_down_factor diff --git a/src/text_recognizer/tests/__init__.py b/src/text_recognizer/tests/__init__.py deleted file mode 100644 index 18ff212..0000000 --- a/src/text_recognizer/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Test modules for the text text recognizer.""" diff --git a/src/text_recognizer/tests/support/__init__.py b/src/text_recognizer/tests/support/__init__.py deleted file mode 100644 index a265ede..0000000 --- a/src/text_recognizer/tests/support/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""Support file modules.""" -from .create_emnist_support_files import create_emnist_support_files diff --git a/src/text_recognizer/tests/support/create_emnist_lines_support_files.py b/src/text_recognizer/tests/support/create_emnist_lines_support_files.py deleted file mode 100644 index 9abe143..0000000 --- a/src/text_recognizer/tests/support/create_emnist_lines_support_files.py +++ /dev/null @@ -1,51 +0,0 @@ -"""Module for creating EMNIST Lines test support files.""" -# flake8: noqa: S106 - -from pathlib import Path -import shutil - -import numpy as np - -from text_recognizer.datasets import 
EmnistLinesDataset -import text_recognizer.util as util - - -SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "emnist_lines" - - -def create_emnist_lines_support_files() -> None: - """Create EMNIST Lines test images.""" - shutil.rmtree(SUPPORT_DIRNAME, ignore_errors=True) - SUPPORT_DIRNAME.mkdir() - - # TODO: maybe have to add args to dataset. - dataset = EmnistLinesDataset( - init_token="<sos>", - pad_token="_", - eos_token="<eos>", - transform=[{"type": "ToTensor", "args": {}}], - target_transform=[ - { - "type": "AddTokens", - "args": {"init_token": "<sos>", "pad_token": "_", "eos_token": "<eos>"}, - } - ], - ) # nosec: S106 - dataset.load_or_generate_data() - - for index in [5, 7, 9]: - image, target = dataset[index] - if len(image.shape) == 3: - image = image.squeeze(0) - print(image.sum(), image.dtype) - - label = "".join(dataset.mapper(label) for label in target[1:]).strip( - dataset.mapper.pad_token - ) - print(label) - image = image.numpy() - util.write_image(image, str(SUPPORT_DIRNAME / f"{label}.png")) - - -if __name__ == "__main__": - create_emnist_lines_support_files() diff --git a/src/text_recognizer/tests/support/create_emnist_support_files.py b/src/text_recognizer/tests/support/create_emnist_support_files.py deleted file mode 100644 index f9ff030..0000000 --- a/src/text_recognizer/tests/support/create_emnist_support_files.py +++ /dev/null @@ -1,30 +0,0 @@ -"""Module for creating EMNIST test support files.""" -from pathlib import Path -import shutil - -from text_recognizer.datasets import EmnistDataset -from text_recognizer.util import write_image - -SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "emnist" - - -def create_emnist_support_files() -> None: - """Create support images for test of CharacterPredictor class.""" - shutil.rmtree(SUPPORT_DIRNAME, ignore_errors=True) - SUPPORT_DIRNAME.mkdir() - - dataset = EmnistDataset(train=False) - dataset.load_or_generate_data() - - for index in [5, 7, 9]: - image, label = dataset[index] - if 
len(image.shape) == 3: - image = image.squeeze(0) - image = image.numpy() - label = dataset.mapper(int(label)) - print(index, label) - write_image(image, str(SUPPORT_DIRNAME / f"{label}.png")) - - -if __name__ == "__main__": - create_emnist_support_files() diff --git a/src/text_recognizer/tests/support/create_iam_lines_support_files.py b/src/text_recognizer/tests/support/create_iam_lines_support_files.py deleted file mode 100644 index 50f9e3d..0000000 --- a/src/text_recognizer/tests/support/create_iam_lines_support_files.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Module for creating IAM Lines test support files.""" -# flake8: noqa -from pathlib import Path -import shutil - -import numpy as np - -from text_recognizer.datasets import IamLinesDataset -import text_recognizer.util as util - - -SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "iam_lines" - - -def create_emnist_lines_support_files() -> None: - """Create IAM Lines test images.""" - shutil.rmtree(SUPPORT_DIRNAME, ignore_errors=True) - SUPPORT_DIRNAME.mkdir() - - # TODO: maybe have to add args to dataset. 
- dataset = IamLinesDataset( - init_token="<sos>", - pad_token="_", - eos_token="<eos>", - transform=[{"type": "ToTensor", "args": {}}], - target_transform=[ - { - "type": "AddTokens", - "args": {"init_token": "<sos>", "pad_token": "_", "eos_token": "<eos>"}, - } - ], - ) - dataset.load_or_generate_data() - - for index in [0, 1, 3]: - image, target = dataset[index] - if len(image.shape) == 3: - image = image.squeeze(0) - print(image.sum(), image.dtype) - - label = "".join(dataset.mapper(label) for label in target[1:]).strip( - dataset.mapper.pad_token - ) - print(label) - image = image.numpy() - util.write_image(image, str(SUPPORT_DIRNAME / f"{label}.png")) - - -if __name__ == "__main__": - create_emnist_lines_support_files() diff --git a/src/text_recognizer/tests/support/emnist/8.png b/src/text_recognizer/tests/support/emnist/8.png Binary files differdeleted file mode 100644 index faa29aa..0000000 --- a/src/text_recognizer/tests/support/emnist/8.png +++ /dev/null diff --git a/src/text_recognizer/tests/support/emnist/U.png b/src/text_recognizer/tests/support/emnist/U.png Binary files differdeleted file mode 100644 index 304eaec..0000000 --- a/src/text_recognizer/tests/support/emnist/U.png +++ /dev/null diff --git a/src/text_recognizer/tests/support/emnist/e.png b/src/text_recognizer/tests/support/emnist/e.png Binary files differdeleted file mode 100644 index a03ecd4..0000000 --- a/src/text_recognizer/tests/support/emnist/e.png +++ /dev/null diff --git a/src/text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png b/src/text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png Binary files differdeleted file mode 100644 index b7d0618..0000000 --- a/src/text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png +++ /dev/null diff --git a/src/text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png b/src/text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png Binary files differdeleted file mode 100644 index 14a8cf3..0000000 
--- a/src/text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png +++ /dev/null diff --git a/src/text_recognizer/tests/support/emnist_lines/they<eos>.png b/src/text_recognizer/tests/support/emnist_lines/they<eos>.png Binary files differdeleted file mode 100644 index 7f05951..0000000 --- a/src/text_recognizer/tests/support/emnist_lines/they<eos>.png +++ /dev/null diff --git a/src/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png b/src/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png Binary files differdeleted file mode 100644 index 6eeb642..0000000 --- a/src/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png +++ /dev/null diff --git a/src/text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png b/src/text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png Binary files differdeleted file mode 100644 index 4974cf8..0000000 --- a/src/text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png +++ /dev/null diff --git a/src/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png b/src/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png Binary files differdeleted file mode 100644 index a731245..0000000 --- a/src/text_recognizer/tests/support/iam_lines/his entrance. 
He came, almost falling<eos>.png +++ /dev/null diff --git a/src/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg b/src/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg Binary files differdeleted file mode 100644 index d9753b6..0000000 --- a/src/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg +++ /dev/null diff --git a/src/text_recognizer/tests/test_character_predictor.py b/src/text_recognizer/tests/test_character_predictor.py deleted file mode 100644 index 01bda78..0000000 --- a/src/text_recognizer/tests/test_character_predictor.py +++ /dev/null @@ -1,31 +0,0 @@ -"""Test for CharacterPredictor class.""" -import importlib -import os -from pathlib import Path -import unittest - -from loguru import logger - -from text_recognizer.character_predictor import CharacterPredictor -from text_recognizer.networks import MLP - -SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "support" / "emnist" - -os.environ["CUDA_VISIBLE_DEVICES"] = "" - - -class TestCharacterPredictor(unittest.TestCase): - """Tests for the CharacterPredictor class.""" - - def test_filename(self) -> None: - """Test that CharacterPredictor correctly predicts on a single image, for serveral test images.""" - network_fn_ = MLP - predictor = CharacterPredictor(network_fn=network_fn_) - - for filename in SUPPORT_DIRNAME.glob("*.png"): - pred, conf = predictor.predict(str(filename)) - logger.info( - f"Prediction: {pred} at confidence: {conf} for image with character {filename.stem}" - ) - self.assertEqual(pred, filename.stem) - self.assertGreater(conf, 0.7) diff --git a/src/text_recognizer/tests/test_line_predictor.py b/src/text_recognizer/tests/test_line_predictor.py deleted file mode 100644 index eede4d4..0000000 --- a/src/text_recognizer/tests/test_line_predictor.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Tests for LinePredictor.""" -import os -from pathlib import Path -import unittest - - -import editdistance -import numpy as np - -from text_recognizer.datasets import 
IamLinesDataset -from text_recognizer.line_predictor import LinePredictor -import text_recognizer.util as util - -SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "support" - -os.environ["CUDA_VISIBLE_DEVICES"] = "" - - -class TestEmnistLinePredictor(unittest.TestCase): - """Test LinePredictor class on the EmnistLines dataset.""" - - def test_filename(self) -> None: - """Test that LinePredictor correctly predicts on single images, for several test images.""" - predictor = LinePredictor( - dataset="EmnistLineDataset", network_fn="CNNTransformer" - ) - - for filename in (SUPPORT_DIRNAME / "emnist_lines").glob("*.png"): - pred, conf = predictor.predict(str(filename)) - true = str(filename.stem) - edit_distance = editdistance.eval(pred, true) / len(pred) - print( - f'Pred: "{pred}" | Confidence: {conf} | True: {true} | Edit distance: {edit_distance}' - ) - self.assertLess(edit_distance, 0.2) diff --git a/src/text_recognizer/tests/test_paragraph_text_recognizer.py b/src/text_recognizer/tests/test_paragraph_text_recognizer.py deleted file mode 100644 index 3e280b9..0000000 --- a/src/text_recognizer/tests/test_paragraph_text_recognizer.py +++ /dev/null @@ -1,37 +0,0 @@ -"""Test for ParagraphTextRecognizer class.""" -import os -from pathlib import Path -import unittest - -from text_recognizer.paragraph_text_recognizer import ParagraphTextRecognizor -import text_recognizer.util as util - - -SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "support" / "iam_paragraph" - -# Prevent using GPU. 
-os.environ["CUDA_VISIBLE_DEVICES"] = "" - - -class TestParagraphTextRecognizor(unittest.TestCase): - """Test that it can take non-square images of max dimension larger than 256px.""" - - def test_filename(self) -> None: - """Test model on support image.""" - line_predictor_args = { - "dataset": "EmnistLineDataset", - "network_fn": "CNNTransformer", - } - line_detector_args = {"dataset": "EmnistLineDataset", "network_fn": "UNet"} - model = ParagraphTextRecognizor( - line_predictor_args=line_predictor_args, - line_detector_args=line_detector_args, - ) - num_text_lines_by_name = {"a01-000u-cropped": 7} - for filename in (SUPPORT_DIRNAME).glob("*.jpg"): - full_image = util.read_image(str(filename), grayscale=True) - predicted_text, line_region_crops = model.predict(full_image) - print(predicted_text) - self.assertTrue( - len(line_region_crops), num_text_lines_by_name[filename.stem] - ) diff --git a/src/text_recognizer/util.py b/src/text_recognizer/util.py deleted file mode 100644 index b431e22..0000000 --- a/src/text_recognizer/util.py +++ /dev/null @@ -1,52 +0,0 @@ -"""Utility functions for text_recognizer module.""" -import os -from pathlib import Path -from typing import Union -from urllib.request import urlopen - -import cv2 -import numpy as np - - -def read_image(image_uri: Union[Path, str], grayscale: bool = False) -> np.ndarray: - """Read image_uri.""" - - def read_image_from_filename(image_filename: str, imread_flag: int) -> np.ndarray: - return cv2.imread(str(image_filename), imread_flag) - - def read_image_from_url(image_url: str, imread_flag: int) -> np.ndarray: - if image_url.lower().startswith("http"): - url_response = urlopen(str(image_url)) - image_array = np.array(bytearray(url_response.read()), dtype=np.uint8) - return cv2.imdecode(image_array, imread_flag) - else: - raise ValueError( - "Url does not start with http, therefore not safe to open..." 
- ) from None - - imread_flag = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR - local_file = os.path.exists(image_uri) - image = None - - if local_file: - image = read_image_from_filename(image_uri, imread_flag) - else: - image = read_image_from_url(image_uri, imread_flag) - - if image is None: - raise ValueError(f"Could not load image at {image_uri}") - - return image - - -def rescale_image(image: np.ndarray) -> np.ndarray: - """Rescale image from [0, 1] to [0, 255].""" - if image.max() <= 1.0: - image = 255 * (image - image.min()) / (image.max() - image.min()) - return image - - -def write_image(image: np.ndarray, filename: Union[Path, str]) -> None: - """Write image to file.""" - image = rescale_image(image) - cv2.imwrite(str(filename), image) diff --git a/src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt b/src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt deleted file mode 100644 index 344e0a3..0000000 --- a/src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:46d483950ef0876ba072d06cd94021e08d99c4fa14eeccf22aeae1cbb2066b4f -size 5628749 diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt deleted file mode 100644 index f2dfd84..0000000 --- a/src/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:8a69e5efedea70c4c5cb8ccdcc8cd480400f6c73e3313423f4dbbfe615644f0a -size 4500617 diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt deleted file mode 100644 index e1add8d..0000000 --- 
a/src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:68dd5c98eedc8753546f88b4e6fd5fc38725dc0079b837c30fb3d48069ec412b -size 15002754 diff --git a/src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_FCN_weights.pt b/src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_FCN_weights.pt Binary files differdeleted file mode 100644 index d9ca01d..0000000 --- a/src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_FCN_weights.pt +++ /dev/null diff --git a/src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_UNet_weights.pt b/src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_UNet_weights.pt Binary files differdeleted file mode 100644 index 0af0e57..0000000 --- a/src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_UNet_weights.pt +++ /dev/null diff --git a/src/text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.pt b/src/text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.pt Binary files differdeleted file mode 100644 index b5295c2..0000000 --- a/src/text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.pt +++ /dev/null |