diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-20 18:09:06 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-20 18:09:06 +0100 |
commit | 7e8e54e84c63171e748bbf09516fd517e6821ace (patch) | |
tree | 996093f75a5d488dddf7ea1f159ed343a561ef89 /text_recognizer/datasets | |
parent | b0719d84138b6bbe5f04a4982dfca673aea1a368 (diff) |
Inital commit for refactoring to lightning
Diffstat (limited to 'text_recognizer/datasets')
-rw-r--r-- | text_recognizer/datasets/__init__.py | 39 | ||||
-rw-r--r-- | text_recognizer/datasets/dataset.py | 152 | ||||
-rw-r--r-- | text_recognizer/datasets/emnist_dataset.py | 131 | ||||
-rw-r--r-- | text_recognizer/datasets/emnist_essentials.json | 1 | ||||
-rw-r--r-- | text_recognizer/datasets/emnist_lines_dataset.py | 359 | ||||
-rw-r--r-- | text_recognizer/datasets/iam_dataset.py | 133 | ||||
-rw-r--r-- | text_recognizer/datasets/iam_lines_dataset.py | 110 | ||||
-rw-r--r-- | text_recognizer/datasets/iam_paragraphs_dataset.py | 291 | ||||
-rw-r--r-- | text_recognizer/datasets/iam_preprocessor.py | 196 | ||||
-rw-r--r-- | text_recognizer/datasets/sentence_generator.py | 81 | ||||
-rw-r--r-- | text_recognizer/datasets/transforms.py | 266 | ||||
-rw-r--r-- | text_recognizer/datasets/util.py | 209 |
12 files changed, 1968 insertions, 0 deletions
diff --git a/text_recognizer/datasets/__init__.py b/text_recognizer/datasets/__init__.py new file mode 100644 index 0000000..a6c1c59 --- /dev/null +++ b/text_recognizer/datasets/__init__.py @@ -0,0 +1,39 @@ +"""Dataset modules.""" +from .emnist_dataset import EmnistDataset +from .emnist_lines_dataset import ( + construct_image_from_string, + EmnistLinesDataset, + get_samples_by_character, +) +from .iam_dataset import IamDataset +from .iam_lines_dataset import IamLinesDataset +from .iam_paragraphs_dataset import IamParagraphsDataset +from .iam_preprocessor import load_metadata, Preprocessor +from .transforms import AddTokens, Transpose +from .util import ( + _download_raw_dataset, + compute_sha256, + DATA_DIRNAME, + download_url, + EmnistMapper, + ESSENTIALS_FILENAME, +) + +__all__ = [ + "_download_raw_dataset", + "AddTokens", + "compute_sha256", + "construct_image_from_string", + "DATA_DIRNAME", + "download_url", + "EmnistDataset", + "EmnistMapper", + "EmnistLinesDataset", + "get_samples_by_character", + "load_metadata", + "IamDataset", + "IamLinesDataset", + "IamParagraphsDataset", + "Preprocessor", + "Transpose", +] diff --git a/text_recognizer/datasets/dataset.py b/text_recognizer/datasets/dataset.py new file mode 100644 index 0000000..e794605 --- /dev/null +++ b/text_recognizer/datasets/dataset.py @@ -0,0 +1,152 @@ +"""Abstract dataset class.""" +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +from torch import Tensor +from torch.utils import data +from torchvision.transforms import ToTensor + +import text_recognizer.datasets.transforms as transforms +from text_recognizer.datasets.util import EmnistMapper + + +class Dataset(data.Dataset): + """Abstract class for with common methods for all datasets.""" + + def __init__( + self, + train: bool, + subsample_fraction: float = None, + transform: Optional[List[Dict]] = None, + target_transform: Optional[List[Dict]] = None, + init_token: Optional[str] = None, + pad_token: Optional[str] = None, + eos_token: Optional[str] = None, + lower: bool = False, + ) -> None: + """Initialization of Dataset class. + + Args: + train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False. + subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None. + transform (Optional[List[Dict]]): List of Transform types and args for input data. Defaults to None. + target_transform (Optional[List[Dict]]): List of Transform types and args for output data. Defaults to None. + init_token (Optional[str]): String representing the start of sequence token. Defaults to None. + pad_token (Optional[str]): String representing the pad token. Defaults to None. + eos_token (Optional[str]): String representing the end of sequence token. Defaults to None. + lower (bool): Only use lower case letters. Defaults to False. + + Raises: + ValueError: If subsample_fraction is not None and outside the range (0, 1). + + """ + self.train = train + self.split = "train" if self.train else "test" + + if subsample_fraction is not None: + if not 0.0 < subsample_fraction < 1.0: + raise ValueError("The subsample fraction must be in (0, 1).") + self.subsample_fraction = subsample_fraction + + self._mapper = EmnistMapper( + init_token=init_token, eos_token=eos_token, pad_token=pad_token, lower=lower + ) + self._input_shape = self._mapper.input_shape + self._output_shape = self._mapper._num_classes + self.num_classes = self.mapper.num_classes + + # Set transforms. + self.transform = self._configure_transform(transform) + self.target_transform = self._configure_target_transform(target_transform) + + self._data = None + self._targets = None + + def _configure_transform(self, transform: List[Dict]) -> transforms.Compose: + transform_list = [] + if transform is not None: + for t in transform: + t_type = t["type"] + t_args = t["args"] or {} + transform_list.append(getattr(transforms, t_type)(**t_args)) + else: + transform_list.append(ToTensor()) + return transforms.Compose(transform_list) + + def _configure_target_transform( + self, target_transform: List[Dict] + ) -> transforms.Compose: + target_transform_list = [torch.tensor] + if target_transform is not None: + for t in target_transform: + t_type = t["type"] + t_args = t["args"] or {} + target_transform_list.append(getattr(transforms, t_type)(**t_args)) + return transforms.Compose(target_transform_list) + + @property + def data(self) -> Tensor: + """The input data.""" + return self._data + + @property + def targets(self) -> Tensor: + """The target data.""" + return self._targets + + @property + def input_shape(self) -> Tuple: + """Input shape of the data.""" + return self._input_shape + + @property + def output_shape(self) -> Tuple: + """Output shape of the data.""" + return self._output_shape + + @property + def mapper(self) -> EmnistMapper: + """Returns the EmnistMapper.""" + return self._mapper + + @property + def mapping(self) -> Dict: + """Return EMNIST mapping from index to character.""" + return self._mapper.mapping + + @property + def inverse_mapping(self) -> Dict: + """Returns the inverse mapping from character to index.""" + return self.mapper.inverse_mapping + + def _subsample(self) -> None: + """Only this fraction of the data will be loaded.""" + if self.subsample_fraction is None: + return + num_subsample = int(self.data.shape[0] * self.subsample_fraction) + self._data = self.data[:num_subsample] + self._targets = self.targets[:num_subsample] + + def __len__(self) -> int: + """Returns the length of the dataset.""" + return len(self.data) + + def load_or_generate_data(self) -> None: + """Load or generate dataset data.""" + raise NotImplementedError + + def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: + """Fetches samples from the dataset. + + Args: + index (Union[int, torch.Tensor]): The indices of the samples to fetch. + + Raises: + NotImplementedError: If the method is not implemented in child class. + + """ + raise NotImplementedError + + def __repr__(self) -> str: + """Returns information about the dataset.""" + raise NotImplementedError diff --git a/text_recognizer/datasets/emnist_dataset.py b/text_recognizer/datasets/emnist_dataset.py new file mode 100644 index 0000000..9884fdf --- /dev/null +++ b/text_recognizer/datasets/emnist_dataset.py @@ -0,0 +1,131 @@ +"""Emnist dataset: black and white images of handwritten characters (Aa-Zz) and digits (0-9).""" + +import json +from pathlib import Path +from typing import Callable, Optional, Tuple, Union + +from loguru import logger +import numpy as np +from PIL import Image +import torch +from torch import Tensor +from torchvision.datasets import EMNIST +from torchvision.transforms import Compose, ToTensor + +from text_recognizer.datasets.dataset import Dataset +from text_recognizer.datasets.transforms import Transpose +from text_recognizer.datasets.util import DATA_DIRNAME + + +class EmnistDataset(Dataset): + """This is a class for resampling and subsampling the PyTorch EMNIST dataset.""" + + def __init__( + self, + pad_token: str = None, + train: bool = False, + sample_to_balance: bool = False, + subsample_fraction: float = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + seed: int = 4711, + ) -> None: + """Loads the dataset and the mappings. + + Args: + pad_token (str): The pad token symbol. Defaults to _. + train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False. + sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False. + subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None. + transform (Optional[Callable]): Transform(s) for input data. Defaults to None. + target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None. + seed (int): Seed number. Defaults to 4711. + + """ + super().__init__( + train=train, + subsample_fraction=subsample_fraction, + transform=transform, + target_transform=target_transform, + pad_token=pad_token, + ) + + self.sample_to_balance = sample_to_balance + + # Have to transpose the emnist characters, ToTensor norms input between [0,1]. + if transform is None: + self.transform = Compose([Transpose(), ToTensor()]) + + self.target_transform = None + + self.seed = seed + + def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: + """Fetches samples from the dataset. + + Args: + index (Union[int, Tensor]): The indices of the samples to fetch. + + Returns: + Tuple[Tensor, Tensor]: Data target tuple. + + """ + if torch.is_tensor(index): + index = index.tolist() + + data = self.data[index] + targets = self.targets[index] + + if self.transform: + data = self.transform(data) + + if self.target_transform: + targets = self.target_transform(targets) + + return data, targets + + def __repr__(self) -> str: + """Returns information about the dataset.""" + return ( + "EMNIST Dataset\n" + f"Num classes: {self.num_classes}\n" + f"Input shape: {self.input_shape}\n" + f"Mapping: {self.mapper.mapping}\n" + ) + + def _sample_to_balance(self) -> None: + """Because the dataset is not balanced, we take at most the mean number of instances per class.""" + np.random.seed(self.seed) + x = self._data + y = self._targets + num_to_sample = int(np.bincount(y.flatten()).mean()) + all_sampled_indices = [] + for label in np.unique(y.flatten()): + inds = np.where(y == label)[0] + sampled_indices = np.unique(np.random.choice(inds, num_to_sample)) + all_sampled_indices.append(sampled_indices) + indices = np.concatenate(all_sampled_indices) + x_sampled = x[indices] + y_sampled = y[indices] + self._data = x_sampled + self._targets = y_sampled + + def load_or_generate_data(self) -> None: + """Fetch the EMNIST dataset.""" + dataset = EMNIST( + root=DATA_DIRNAME, + split="byclass", + train=self.train, + download=False, + transform=None, + target_transform=None, + ) + + self._data = dataset.data + self._targets = dataset.targets + + if self.sample_to_balance: + self._sample_to_balance() + + if self.subsample_fraction is not None: + self._subsample() diff --git a/text_recognizer/datasets/emnist_essentials.json b/text_recognizer/datasets/emnist_essentials.json new file mode 100644 index 0000000..2a0648a --- /dev/null +++ b/text_recognizer/datasets/emnist_essentials.json @@ -0,0 +1 @@ +{"mapping": [[0, "0"], [1, "1"], [2, "2"], [3, "3"], [4, "4"], [5, "5"], [6, "6"], [7, "7"], [8, "8"], [9, "9"], [10, "A"], [11, "B"], [12, "C"], [13, "D"], [14, "E"], [15, "F"], [16, "G"], [17, "H"], [18, "I"], [19, "J"], [20, "K"], [21, "L"], [22, "M"], [23, "N"], [24, "O"], [25, "P"], [26, "Q"], [27, "R"], [28, "S"], [29, "T"], [30, "U"], [31, "V"], [32, "W"], [33, "X"], [34, "Y"], [35, "Z"], [36, "a"], [37, "b"], [38, "c"], [39, "d"], [40, "e"], [41, "f"], [42, "g"], [43, "h"], [44, "i"], [45, "j"], [46, "k"], [47, "l"], [48, "m"], [49, "n"], [50, "o"], [51, "p"], [52, "q"], [53, "r"], [54, "s"], [55, "t"], [56, "u"], [57, "v"], [58, "w"], [59, "x"], [60, "y"], [61, "z"]], "input_shape": [28, 28]} diff --git a/text_recognizer/datasets/emnist_lines_dataset.py b/text_recognizer/datasets/emnist_lines_dataset.py new file mode 100644 index 0000000..1992446 --- /dev/null +++ b/text_recognizer/datasets/emnist_lines_dataset.py @@ -0,0 +1,359 @@ +"""Emnist Lines dataset: synthetic handwritten lines dataset made from Emnist characters.""" + +from collections import defaultdict +from pathlib import Path +from typing import Callable, Dict, List, Optional, Tuple, Union + +import click +import h5py +from loguru import logger +import numpy as np +import torch +from torch import Tensor +import torch.nn.functional as F +from torchvision.transforms import ToTensor + +from text_recognizer.datasets.dataset import Dataset +from text_recognizer.datasets.emnist_dataset import EmnistDataset, Transpose +from text_recognizer.datasets.sentence_generator import SentenceGenerator +from text_recognizer.datasets.util import ( + DATA_DIRNAME, + EmnistMapper, + ESSENTIALS_FILENAME, +) + +DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines" + +MAX_WIDTH = 952 + + +class EmnistLinesDataset(Dataset): + """Synthetic dataset of lines from the Brown corpus with Emnist characters.""" + + def __init__( + self, + train: bool = False, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + subsample_fraction: float = None, + max_length: int = 34, + min_overlap: float = 0, + max_overlap: float = 0.33, + num_samples: int = 10000, + seed: int = 4711, + init_token: Optional[str] = None, + pad_token: Optional[str] = None, + eos_token: Optional[str] = None, + lower: bool = False, + ) -> None: + """Set attributes and loads the dataset. + + Args: + train (bool): Flag for the filename. Defaults to False. Defaults to None. + transform (Optional[Callable]): The transform of the data. Defaults to None. + target_transform (Optional[Callable]): The transform of the target. Defaults to None. + subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None. + max_length (int): The maximum number of characters. Defaults to 34. + min_overlap (float): The minimum overlap between concatenated images. Defaults to 0. + max_overlap (float): The maximum overlap between concatenated images. Defaults to 0.33. + num_samples (int): Number of samples to generate. Defaults to 10000. + seed (int): Seed number. Defaults to 4711. + init_token (Optional[str]): String representing the start of sequence token. Defaults to None. + pad_token (Optional[str]): String representing the pad token. Defaults to None. + eos_token (Optional[str]): String representing the end of sequence token. Defaults to None. + lower (bool): If True, convert uppercase letters to lowercase. Otherwise, use both upper and lowercase. + + """ + self.pad_token = "_" if pad_token is None else pad_token + + super().__init__( + train=train, + transform=transform, + target_transform=target_transform, + subsample_fraction=subsample_fraction, + init_token=init_token, + pad_token=self.pad_token, + eos_token=eos_token, + lower=lower, + ) + + # Extract dataset information. + self._input_shape = self._mapper.input_shape + self.num_classes = self._mapper.num_classes + + self.max_length = max_length + self.min_overlap = min_overlap + self.max_overlap = max_overlap + self.num_samples = num_samples + self._input_shape = ( + self.input_shape[0], + self.input_shape[1] * self.max_length, + ) + self._output_shape = (self.max_length, self.num_classes) + self.seed = seed + + # Placeholders for the dataset. + self._data = None + self._target = None + + def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: + """Fetches data, target pair of the dataset for a given and index or indices. + + Args: + index (Union[int, Tensor]): Either a list or int of indices/index. + + Returns: + Tuple[Tensor, Tensor]: Data target pair. + + """ + if torch.is_tensor(index): + index = index.tolist() + + data = self.data[index] + targets = self.targets[index] + + if self.transform: + data = self.transform(data) + + if self.target_transform: + targets = self.target_transform(targets) + + return data, targets + + def __repr__(self) -> str: + """Returns information about the dataset.""" + return ( + "EMNIST Lines Dataset\n" # pylint: disable=no-member + f"Max length: {self.max_length}\n" + f"Min overlap: {self.min_overlap}\n" + f"Max overlap: {self.max_overlap}\n" + f"Num classes: {self.num_classes}\n" + f"Input shape: {self.input_shape}\n" + f"Data: {self.data.shape}\n" + f"Tagets: {self.targets.shape}\n" + ) + + @property + def data_filename(self) -> Path: + """Path to the h5 file.""" + filename = "train.pt" if self.train else "test.pt" + return DATA_DIRNAME / filename + + def load_or_generate_data(self) -> None: + """Loads the dataset, if it does not exist a new dataset is generated before loading it.""" + np.random.seed(self.seed) + + if not self.data_filename.exists(): + self._generate_data() + self._load_data() + self._subsample() + + def _load_data(self) -> None: + """Loads the dataset from the h5 file.""" + logger.debug("EmnistLinesDataset loading data from HDF5...") + with h5py.File(self.data_filename, "r") as f: + self._data = f["data"][()] + self._targets = f["targets"][()] + + def _generate_data(self) -> str: + """Generates a dataset with the Brown corpus and Emnist characters.""" + logger.debug("Generating data...") + + sentence_generator = SentenceGenerator(self.max_length) + + # Load emnist dataset. + emnist = EmnistDataset( + train=self.train, sample_to_balance=True, pad_token=self.pad_token + ) + emnist.load_or_generate_data() + + samples_by_character = get_samples_by_character( + emnist.data.numpy(), emnist.targets.numpy(), self.mapper.mapping, + ) + + DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + with h5py.File(self.data_filename, "a") as f: + data, targets = create_dataset_of_images( + self.num_samples, + samples_by_character, + sentence_generator, + self.min_overlap, + self.max_overlap, + ) + + targets = convert_strings_to_categorical_labels( + targets, emnist.inverse_mapping + ) + + f.create_dataset("data", data=data, dtype="u1", compression="lzf") + f.create_dataset("targets", data=targets, dtype="u1", compression="lzf") + + +def get_samples_by_character( + samples: np.ndarray, labels: np.ndarray, mapping: Dict +) -> defaultdict: + """Creates a dictionary with character as key and value as the list of images of that character. + + Args: + samples (np.ndarray): Dataset of images of characters. + labels (np.ndarray): The labels for each image. + mapping (Dict): The Emnist mapping dictionary. + + Returns: + defaultdict: A dictionary with characters as keys and list of images as values. + + """ + samples_by_character = defaultdict(list) + for sample, label in zip(samples, labels.flatten()): + samples_by_character[mapping[label]].append(sample) + return samples_by_character + + +def select_letter_samples_for_string( + string: str, samples_by_character: Dict +) -> List[np.ndarray]: + """Randomly selects Emnist characters to use for the senetence. + + Args: + string (str): The word or sentence. + samples_by_character (Dict): The dictionary of emnist images of each character. + + Returns: + List[np.ndarray]: A list of emnist images of the string. + + """ + zero_image = np.zeros((28, 28), np.uint8) + sample_image_by_character = {} + for character in string: + if character in sample_image_by_character: + continue + samples = samples_by_character[character] + sample = samples[np.random.choice(len(samples))] if samples else zero_image + sample_image_by_character[character] = sample.reshape(28, 28).swapaxes(0, 1) + return [sample_image_by_character[character] for character in string] + + +def construct_image_from_string( + string: str, samples_by_character: Dict, min_overlap: float, max_overlap: float +) -> np.ndarray: + """Concatenates images of the characters in the string. + + The concatination is made with randomly selected overlap so that some portion of the character will overlap. + + Args: + string (str): The word or sentence. + samples_by_character (Dict): The dictionary of emnist images of each character. + min_overlap (float): Minimum amount of overlap between Emnist images. + max_overlap (float): Maximum amount of overlap between Emnist images. + + Returns: + np.ndarray: The Emnist image of the string. + + """ + overlap = np.random.uniform(min_overlap, max_overlap) + sampled_images = select_letter_samples_for_string(string, samples_by_character) + length = len(sampled_images) + height, width = sampled_images[0].shape + next_overlap_width = width - int(overlap * width) + concatenated_image = np.zeros((height, width * length), np.uint8) + x = 0 + for image in sampled_images: + concatenated_image[:, x : (x + width)] += image + x += next_overlap_width + + if concatenated_image.shape[-1] > MAX_WIDTH: + concatenated_image = Tensor(concatenated_image).unsqueeze(0) + concatenated_image = F.interpolate( + concatenated_image, size=MAX_WIDTH, mode="nearest" + ) + concatenated_image = concatenated_image.squeeze(0).numpy() + + return np.minimum(255, concatenated_image) + + +def create_dataset_of_images( + length: int, + samples_by_character: Dict, + sentence_generator: SentenceGenerator, + min_overlap: float, + max_overlap: float, +) -> Tuple[np.ndarray, List[str]]: + """Creates a dataset with images and labels from strings generated from the SentenceGenerator. + + Args: + length (int): The number of characters for each string. + samples_by_character (Dict): The dictionary of emnist images of each character. + sentence_generator (SentenceGenerator): A SentenceGenerator objest. + min_overlap (float): Minimum amount of overlap between Emnist images. + max_overlap (float): Maximum amount of overlap between Emnist images. + + Returns: + Tuple[np.ndarray, List[str]]: A list of Emnist images and a list of the strings (labels). + + Raises: + RuntimeError: If the sentence generator is not able to generate a string. + + """ + sample_label = sentence_generator.generate() + sample_image = construct_image_from_string(sample_label, samples_by_character, 0, 0) + images = np.zeros((length, sample_image.shape[0], sample_image.shape[1]), np.uint8) + labels = [] + for n in range(length): + label = None + # Try several times to generate before actually throwing an error. + for _ in range(10): + try: + label = sentence_generator.generate() + break + except Exception: # pylint: disable=broad-except + pass + if label is None: + raise RuntimeError("Was not able to generate a valid string.") + images[n] = construct_image_from_string( + label, samples_by_character, min_overlap, max_overlap + ) + labels.append(label) + return images, labels + + +def convert_strings_to_categorical_labels( + labels: List[str], mapping: Dict +) -> np.ndarray: + """Translates a string of characters in to a target array of class int.""" + return np.array([[mapping[c] for c in label] for label in labels]) + + +@click.command() +@click.option( + "--max_length", type=int, default=34, help="Number of characters in a sentence." +) +@click.option( + "--min_overlap", type=float, default=0.0, help="Min overlap between characters." +) +@click.option( + "--max_overlap", type=float, default=0.33, help="Max overlap between characters." +) +@click.option("--num_train", type=int, default=10_000, help="Number of train examples.") +@click.option("--num_test", type=int, default=1_000, help="Number of test examples.") +def create_datasets( + max_length: int = 34, + min_overlap: float = 0, + max_overlap: float = 0.33, + num_train: int = 10000, + num_test: int = 1000, +) -> None: + """Creates a training an validation dataset of Emnist lines.""" + num_samples = [num_train, num_test] + for num, train in zip(num_samples, [True, False]): + emnist_lines = EmnistLinesDataset( + train=train, + max_length=max_length, + min_overlap=min_overlap, + max_overlap=max_overlap, + num_samples=num, + ) + emnist_lines.load_or_generate_data() + + +if __name__ == "__main__": + create_datasets() diff --git a/text_recognizer/datasets/iam_dataset.py b/text_recognizer/datasets/iam_dataset.py new file mode 100644 index 0000000..a8998b9 --- /dev/null +++ b/text_recognizer/datasets/iam_dataset.py @@ -0,0 +1,133 @@ +"""Class for loading the IAM dataset, which encompasses both paragraphs and lines, with associated utilities.""" +import os +from typing import Any, Dict, List +import zipfile + +from boltons.cacheutils import cachedproperty +import defusedxml.ElementTree as ET +from loguru import logger +import toml + +from text_recognizer.datasets.util import _download_raw_dataset, DATA_DIRNAME + +RAW_DATA_DIRNAME = DATA_DIRNAME / "raw" / "iam" +METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" +EXTRACTED_DATASET_DIRNAME = RAW_DATA_DIRNAME / "iamdb" +RAW_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + +DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be. +LINE_REGION_PADDING = 0 # Add this many pixels around the exact coordinates. + + +class IamDataset: + """IAM dataset. + + "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text, + which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels." + From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database + + The data split we will use is + IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text lines. + The validation set has been merged into the train set. + The train set has 7,101 lines from 326 writers. + The test set has 1,861 lines from 128 writers. + The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only. + + """ + + def __init__(self) -> None: + self.metadata = toml.load(METADATA_FILENAME) + + def load_or_generate_data(self) -> None: + """Downloads IAM dataset if xml files does not exist.""" + if not self.xml_filenames: + self._download_iam() + + @property + def xml_filenames(self) -> List: + """List of xml filenames.""" + return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml")) + + @property + def form_filenames(self) -> List: + """List of forms filenames.""" + return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg")) + + def _download_iam(self) -> None: + curdir = os.getcwd() + os.chdir(RAW_DATA_DIRNAME) + _download_raw_dataset(self.metadata) + _extract_raw_dataset(self.metadata) + os.chdir(curdir) + + @property + def form_filenames_by_id(self) -> Dict: + """Creates a dictionary with filenames as keys and forms as values.""" + return {filename.stem: filename for filename in self.form_filenames} + + @cachedproperty + def line_strings_by_id(self) -> Dict: + """Return a dict from name of IAM form to a list of line texts in it.""" + return { + filename.stem: _get_line_strings_from_xml_file(filename) + for filename in self.xml_filenames + } + + @cachedproperty + def line_regions_by_id(self) -> Dict: + """Return a dict from name of IAM form to a list of (x1, x2, y1, y2) coordinates of all lines in it.""" + return { + filename.stem: _get_line_regions_from_xml_file(filename) + for filename in self.xml_filenames + } + + def __repr__(self) -> str: + """Print info about dataset.""" + return "IAM Dataset\n" f"Number of forms: {len(self.xml_filenames)}\n" + + +def _extract_raw_dataset(metadata: Dict) -> None: + logger.info("Extracting IAM data.") + with zipfile.ZipFile(metadata["filename"], "r") as zip_file: + zip_file.extractall() + + +def _get_line_strings_from_xml_file(filename: str) -> List[str]: + """Get the text content of each line. Note that we replace " with ".""" + xml_root_element = ET.parse(filename).getroot() # nosec + xml_line_elements = xml_root_element.findall("handwritten-part/line") + return [el.attrib["text"].replace(""", '"') for el in xml_line_elements] + + +def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]: + """Get the line region dict for each line.""" + xml_root_element = ET.parse(filename).getroot() # nosec + xml_line_elements = xml_root_element.findall("handwritten-part/line") + return [_get_line_region_from_xml_element(el) for el in xml_line_elements] + + +def _get_line_region_from_xml_element(xml_line: Any) -> Dict[str, int]: + """Extracts coordinates for each line of text.""" + # TODO: fix input! + word_elements = xml_line.findall("word/cmp") + x1s = [int(el.attrib["x"]) for el in word_elements] + y1s = [int(el.attrib["y"]) for el in word_elements] + x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements] + y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements] + return { + "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, + "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, + "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, + "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, + } + + +def main() -> None: + """Initializes the dataset and print info about the dataset.""" + dataset = IamDataset() + dataset.load_or_generate_data() + print(dataset) + + +if __name__ == "__main__": + main() diff --git a/text_recognizer/datasets/iam_lines_dataset.py b/text_recognizer/datasets/iam_lines_dataset.py new file mode 100644 index 0000000..1cb84bd --- /dev/null +++ b/text_recognizer/datasets/iam_lines_dataset.py @@ -0,0 +1,110 @@ +"""IamLinesDataset class.""" +from typing import Callable, Dict, List, Optional, Tuple, Union + +import h5py +from loguru import logger +import torch +from torch import Tensor +from torchvision.transforms import ToTensor + +from text_recognizer.datasets.dataset import Dataset +from text_recognizer.datasets.util import ( + compute_sha256, + DATA_DIRNAME, + download_url, + EmnistMapper, +) + + +PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines" +PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "iam_lines.h5" +PROCESSED_DATA_URL = ( + "https://s3-us-west-2.amazonaws.com/fsdl-public-assets/iam_lines.h5" +) + + +class IamLinesDataset(Dataset): + """IAM lines datasets for handwritten text lines.""" + + def __init__( + self, + train: bool = False, + subsample_fraction: float = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + init_token: Optional[str] = None, + pad_token: Optional[str] = None, + eos_token: Optional[str] = None, + lower: bool = False, + ) -> None: + self.pad_token = "_" if pad_token is None else pad_token + + super().__init__( + train=train, + subsample_fraction=subsample_fraction, + transform=transform, + target_transform=target_transform, + init_token=init_token, + pad_token=pad_token, + eos_token=eos_token, + lower=lower, + ) + + @property + def input_shape(self) -> Tuple: + """Input shape of the data.""" + return self.data.shape[1:] if self.data is not None else None + + @property + def output_shape(self) -> Tuple: + """Output shape of the data.""" + return ( + self.targets.shape[1:] + (self.num_classes,) + if self.targets is not None + else None + ) + + def load_or_generate_data(self) -> None: + """Load or generate dataset data.""" + if not PROCESSED_DATA_FILENAME.exists(): + PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + logger.info("Downloading IAM lines...") + download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME) + with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: + self._data = f[f"x_{self.split}"][:] + self._targets = f[f"y_{self.split}"][:] + self._subsample() + + def __repr__(self) -> str: + """Print info about the dataset.""" + return ( + "IAM Lines Dataset\n" # pylint: disable=no-member + f"Number classes: {self.num_classes}\n" + f"Mapping: {self.mapper.mapping}\n" + f"Data: {self.data.shape}\n" + f"Targets: {self.targets.shape}\n" + ) + + def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]: + """Fetches data, target pair of the dataset for a given and index or indices. + + Args: + index (Union[int, Tensor]): Either a list or int of indices/index. + + Returns: + Tuple[Tensor, Tensor]: Data target pair. + + """ + if torch.is_tensor(index): + index = index.tolist() + + data = self.data[index] + targets = self.targets[index] + + if self.transform: + data = self.transform(data) + + if self.target_transform: + targets = self.target_transform(targets) + + return data, targets diff --git a/text_recognizer/datasets/iam_paragraphs_dataset.py b/text_recognizer/datasets/iam_paragraphs_dataset.py new file mode 100644 index 0000000..8ba5142 --- /dev/null +++ b/text_recognizer/datasets/iam_paragraphs_dataset.py @@ -0,0 +1,291 @@ +"""IamParagraphsDataset class and functions for data processing.""" +import random +from typing import Callable, Dict, List, Optional, Tuple, Union + +import click +import cv2 +import h5py +from loguru import logger +import numpy as np +import torch +from torch import Tensor +from torchvision.transforms import ToTensor + +from text_recognizer import util +from text_recognizer.datasets.dataset import Dataset +from text_recognizer.datasets.iam_dataset import IamDataset +from text_recognizer.datasets.util import ( + compute_sha256, + DATA_DIRNAME, + download_url, + EmnistMapper, +) + +INTERIM_DATA_DIRNAME = DATA_DIRNAME / "interim" / "iam_paragraphs" +DEBUG_CROPS_DIRNAME = INTERIM_DATA_DIRNAME / "debug_crops" +PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_paragraphs" +CROPS_DIRNAME = PROCESSED_DATA_DIRNAME / "crops" +GT_DIRNAME = PROCESSED_DATA_DIRNAME / "gt" + +PARAGRAPH_BUFFER = 50 # Pixels in the IAM form images to leave around the lines. +TEST_FRACTION = 0.2 +SEED = 4711 + + +class IamParagraphsDataset(Dataset): + """IAM Paragraphs dataset for paragraphs of handwritten text.""" + + def __init__( + self, + train: bool = False, + subsample_fraction: float = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: + super().__init__( + train=train, + subsample_fraction=subsample_fraction, + transform=transform, + target_transform=target_transform, + ) + # Load Iam dataset. + self.iam_dataset = IamDataset() + + self.num_classes = 3 + self._input_shape = (256, 256) + self._output_shape = self._input_shape + (self.num_classes,) + self._ids = None + + def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]: + """Fetches data, target pair of the dataset for a given and index or indices. + + Args: + index (Union[int, Tensor]): Either a list or int of indices/index. + + Returns: + Tuple[Tensor, Tensor]: Data target pair. + + """ + if torch.is_tensor(index): + index = index.tolist() + + data = self.data[index] + targets = self.targets[index] + + seed = np.random.randint(SEED) + random.seed(seed) # apply this seed to target tranfsorms + torch.manual_seed(seed) # needed for torchvision 0.7 + if self.transform: + data = self.transform(data) + + random.seed(seed) # apply this seed to target tranfsorms + torch.manual_seed(seed) # needed for torchvision 0.7 + if self.target_transform: + targets = self.target_transform(targets) + + return data, targets.long() + + @property + def ids(self) -> Tensor: + """Ids of the dataset.""" + return self._ids + + def get_data_and_target_from_id(self, id_: str) -> Tuple[Tensor, Tensor]: + """Get data target pair from id.""" + ind = self.ids.index(id_) + return self.data[ind], self.targets[ind] + + def load_or_generate_data(self) -> None: + """Load or generate dataset data.""" + num_actual = len(list(CROPS_DIRNAME.glob("*.jpg"))) + num_targets = len(self.iam_dataset.line_regions_by_id) + + if num_actual < num_targets - 2: + self._process_iam_paragraphs() + + self._data, self._targets, self._ids = _load_iam_paragraphs() + self._get_random_split() + self._subsample() + + def _get_random_split(self) -> None: + np.random.seed(SEED) + num_train = int((1 - TEST_FRACTION) * self.data.shape[0]) + indices = np.random.permutation(self.data.shape[0]) + train_indices, test_indices = indices[:num_train], indices[num_train:] + if self.train: + self._data = self.data[train_indices] + self._targets = self.targets[train_indices] + else: + self._data = self.data[test_indices] + self._targets = self.targets[test_indices] + + def _process_iam_paragraphs(self) -> None: + """Crop the part with the text. + + For each page, crop out the part of it that correspond to the paragraph of text, and make sure all crops are + self.input_shape. The ground truth data is the same size, with a one-hot vector at each pixel + corresponding to labels 0=background, 1=odd-numbered line, 2=even-numbered line + """ + crop_dims = self._decide_on_crop_dims() + CROPS_DIRNAME.mkdir(parents=True, exist_ok=True) + DEBUG_CROPS_DIRNAME.mkdir(parents=True, exist_ok=True) + GT_DIRNAME.mkdir(parents=True, exist_ok=True) + logger.info( + f"Cropping paragraphs, generating ground truth, and saving debugging images to {DEBUG_CROPS_DIRNAME}" + ) + for filename in self.iam_dataset.form_filenames: + id_ = filename.stem + line_region = self.iam_dataset.line_regions_by_id[id_] + _crop_paragraph_image(filename, line_region, crop_dims, self.input_shape) + + def _decide_on_crop_dims(self) -> Tuple[int, int]: + """Decide on the dimensions to crop out of the form image. + + Since image width is larger than a comfortable crop around the longest paragraph, + we will make the crop a square form factor. + And since the found dimensions 610x610 are pretty close to 512x512, + we might as well resize crops and make it exactly that, which lets us + do all kinds of power-of-2 pooling and upsampling should we choose to. + + Returns: + Tuple[int, int]: A tuple of crop dimensions. + + Raises: + RuntimeError: When max crop height is larger than max crop width. + + """ + + sample_form_filename = self.iam_dataset.form_filenames[0] + sample_image = util.read_image(sample_form_filename, grayscale=True) + max_crop_width = sample_image.shape[1] + max_crop_height = _get_max_paragraph_crop_height( + self.iam_dataset.line_regions_by_id + ) + if not max_crop_height <= max_crop_width: + raise RuntimeError( + f"Max crop height is larger then max crop width: {max_crop_height} >= {max_crop_width}" + ) + + crop_dims = (max_crop_width, max_crop_width) + logger.info( + f"Max crop width and height were found to be {max_crop_width}x{max_crop_height}." + ) + logger.info(f"Setting them to {max_crop_width}x{max_crop_width}") + return crop_dims + + def __repr__(self) -> str: + """Return info about the dataset.""" + return ( + "IAM Paragraph Dataset\n" # pylint: disable=no-member + f"Num classes: {self.num_classes}\n" + f"Data: {self.data.shape}\n" + f"Targets: {self.targets.shape}\n" + ) + + +def _get_max_paragraph_crop_height(line_regions_by_id: Dict) -> int: + heights = [] + for regions in line_regions_by_id.values(): + min_y1 = min(region["y1"] for region in regions) - PARAGRAPH_BUFFER + max_y2 = max(region["y2"] for region in regions) + PARAGRAPH_BUFFER + height = max_y2 - min_y1 + heights.append(height) + return max(heights) + + +def _crop_paragraph_image( + filename: str, line_regions: Dict, crop_dims: Tuple[int, int], final_dims: Tuple +) -> None: + image = util.read_image(filename, grayscale=True) + + min_y1 = min(region["y1"] for region in line_regions) - PARAGRAPH_BUFFER + max_y2 = max(region["y2"] for region in line_regions) + PARAGRAPH_BUFFER + height = max_y2 - min_y1 + crop_height = crop_dims[0] + buffer = (crop_height - height) // 2 + + # Generate image crop. + image_crop = 255 * np.ones(crop_dims, dtype=np.uint8) + try: + image_crop[buffer : buffer + height] = image[min_y1:max_y2] + except Exception as e: # pylint: disable=broad-except + logger.error(f"Rescued {filename}: {e}") + return + + # Generate ground truth. + gt_image = np.zeros_like(image_crop, dtype=np.uint8) + for index, region in enumerate(line_regions): + gt_image[ + (region["y1"] - min_y1 + buffer) : (region["y2"] - min_y1 + buffer), + region["x1"] : region["x2"], + ] = (index % 2 + 1) + + # Generate image for debugging. + import matplotlib.pyplot as plt + + cmap = plt.get_cmap("Set1") + image_crop_for_debug = np.dstack([image_crop, image_crop, image_crop]) + for index, region in enumerate(line_regions): + color = [255 * _ for _ in cmap(index)[:-1]] + cv2.rectangle( + image_crop_for_debug, + (region["x1"], region["y1"] - min_y1 + buffer), + (region["x2"], region["y2"] - min_y1 + buffer), + color, + 3, + ) + image_crop_for_debug = cv2.resize( + image_crop_for_debug, final_dims, interpolation=cv2.INTER_AREA + ) + util.write_image(image_crop_for_debug, DEBUG_CROPS_DIRNAME / f"{filename.stem}.jpg") + + image_crop = cv2.resize(image_crop, final_dims, interpolation=cv2.INTER_AREA) + util.write_image(image_crop, CROPS_DIRNAME / f"{filename.stem}.jpg") + + gt_image = cv2.resize(gt_image, final_dims, interpolation=cv2.INTER_NEAREST) + util.write_image(gt_image, GT_DIRNAME / f"{filename.stem}.png") + + +def _load_iam_paragraphs() -> None: + logger.info("Loading IAM paragraph crops and ground truth from image files...") + images = [] + gt_images = [] + ids = [] + for filename in CROPS_DIRNAME.glob("*.jpg"): + id_ = filename.stem + image = util.read_image(filename, grayscale=True) + image = 1.0 - image / 255 + + gt_filename = GT_DIRNAME / f"{id_}.png" + gt_image = util.read_image(gt_filename, grayscale=True) + + images.append(image) + gt_images.append(gt_image) + ids.append(id_) + images = np.array(images).astype(np.float32) + gt_images = np.array(gt_images).astype(np.uint8) + ids = np.array(ids) + return images, gt_images, ids + + +@click.command() +@click.option( + "--subsample_fraction", + type=float, + default=None, + help="The subsampling factor of the dataset.", +) +def main(subsample_fraction: float) -> None: + """Load dataset and print info.""" + logger.info("Creating train set...") + dataset = IamParagraphsDataset(train=True, subsample_fraction=subsample_fraction) + dataset.load_or_generate_data() + print(dataset) + logger.info("Creating test set...") + dataset = IamParagraphsDataset(subsample_fraction=subsample_fraction) + dataset.load_or_generate_data() + print(dataset) + + +if __name__ == "__main__": + main() diff --git a/text_recognizer/datasets/iam_preprocessor.py b/text_recognizer/datasets/iam_preprocessor.py new file mode 100644 index 0000000..a93eb00 --- /dev/null +++ b/text_recognizer/datasets/iam_preprocessor.py @@ -0,0 +1,196 @@ +"""Preprocessor for extracting word letters from the IAM dataset. + +The code is mostly stolen from: + + https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py + +""" + +import collections +import itertools +from pathlib import Path +import re +from typing import List, Optional, Union + +import click +from loguru import logger +import torch + + +def load_metadata( + data_dir: Path, wordsep: str, use_words: bool = False +) -> collections.defaultdict: + """Loads IAM metadata and returns it as a dictionary.""" + forms = collections.defaultdict(list) + filename = "words.txt" if use_words else "lines.txt" + + with open(data_dir / "ascii" / filename, "r") as f: + lines = (line.strip().split() for line in f if line[0] != "#") + for line in lines: + # Skip word segmentation errors. + if use_words and line[1] == "err": + continue + text = " ".join(line[8:]) + + # Remove garbage tokens: + text = text.replace("#", "") + + # Swap word sep form | to wordsep + text = re.sub(r"\|+|\s", wordsep, text).strip(wordsep) + form_key = "-".join(line[0].split("-")[:2]) + line_key = "-".join(line[0].split("-")[:3]) + box_idx = 4 - use_words + box = tuple(int(val) for val in line[box_idx : box_idx + 4]) + forms[form_key].append({"key": line_key, "box": box, "text": text}) + return forms + + +class Preprocessor: + """A preprocessor for the IAM dataset.""" + + # TODO: add lower case only to when generating... + + def __init__( + self, + data_dir: Union[str, Path], + num_features: int, + tokens_path: Optional[Union[str, Path]] = None, + lexicon_path: Optional[Union[str, Path]] = None, + use_words: bool = False, + prepend_wordsep: bool = False, + ) -> None: + self.wordsep = "▁" + self._use_word = use_words + self._prepend_wordsep = prepend_wordsep + + self.data_dir = Path(data_dir) + + self.forms = load_metadata(self.data_dir, self.wordsep, use_words=use_words) + + # Load the set of graphemes: + graphemes = set() + for _, form in self.forms.items(): + for line in form: + graphemes.update(line["text"].lower()) + self.graphemes = sorted(graphemes) + + # Build the token-to-index and index-to-token maps. + if tokens_path is not None: + with open(tokens_path, "r") as f: + self.tokens = [line.strip() for line in f] + else: + self.tokens = self.graphemes + + if lexicon_path is not None: + with open(lexicon_path, "r") as f: + lexicon = (line.strip().split() for line in f) + lexicon = {line[0]: line[1:] for line in lexicon} + self.lexicon = lexicon + else: + self.lexicon = None + + self.graphemes_to_index = {t: i for i, t in enumerate(self.graphemes)} + self.tokens_to_index = {t: i for i, t in enumerate(self.tokens)} + self.num_features = num_features + self.text = [] + + @property + def num_tokens(self) -> int: + """Returns the number or tokens.""" + return len(self.tokens) + + @property + def use_words(self) -> bool: + """If words are used.""" + return self._use_word + + def extract_train_text(self) -> None: + """Extracts training text.""" + keys = [] + with open(self.data_dir / "task" / "trainset.txt") as f: + keys.extend((line.strip() for line in f)) + + for _, examples in self.forms.items(): + for example in examples: + if example["key"] not in keys: + continue + self.text.append(example["text"].lower()) + + def to_index(self, line: str) -> torch.LongTensor: + """Converts text to a tensor of indices.""" + token_to_index = self.graphemes_to_index + if self.lexicon is not None: + if len(line) > 0: + # If the word is not found in the lexicon, fall back to letters. + line = [ + t + for w in line.split(self.wordsep) + for t in self.lexicon.get(w, self.wordsep + w) + ] + token_to_index = self.tokens_to_index + if self._prepend_wordsep: + line = itertools.chain([self.wordsep], line) + return torch.LongTensor([token_to_index[t] for t in line]) + + def to_text(self, indices: List[int]) -> str: + """Converts indices to text.""" + # Roughly the inverse of `to_index` + encoding = self.graphemes + if self.lexicon is not None: + encoding = self.tokens + return self._post_process(encoding[i] for i in indices) + + def tokens_to_text(self, indices: List[int]) -> str: + """Converts tokens to text.""" + return self._post_process(self.tokens[i] for i in indices) + + def _post_process(self, indices: List[int]) -> str: + """A list join.""" + return "".join(indices).strip(self.wordsep) + + +@click.command() +@click.option("--data_dir", type=str, default=None, help="Path to iam dataset") +@click.option( + "--use_words", is_flag=True, help="Load word segmented dataset instead of lines" +) +@click.option( + "--save_text", type=str, default=None, help="Path to save parsed train text" +) +@click.option("--save_tokens", type=str, default=None, help="Path to save tokens") +def cli( + data_dir: Optional[str], + use_words: bool, + save_text: Optional[str], + save_tokens: Optional[str], +) -> None: + """CLI for extracting text data from the iam dataset.""" + if data_dir is None: + data_dir = ( + Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb" + ) + logger.debug(f"Using data dir: {data_dir}") + if not data_dir.exists(): + raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") + else: + data_dir = Path(data_dir) + + preprocessor = Preprocessor(data_dir, 64, use_words=use_words) + preprocessor.extract_train_text() + + processed_dir = data_dir.parents[2] / "processed" / "iam_lines" + logger.debug(f"Saving processed files at: {processed_dir}") + + if save_text is not None: + logger.info("Saving training text") + with open(processed_dir / save_text, "w") as f: + f.write("\n".join(t for t in preprocessor.text)) + + if save_tokens is not None: + logger.info("Saving tokens") + with open(processed_dir / save_tokens, "w") as f: + f.write("\n".join(preprocessor.tokens)) + + +if __name__ == "__main__": + cli() diff --git a/text_recognizer/datasets/sentence_generator.py b/text_recognizer/datasets/sentence_generator.py new file mode 100644 index 0000000..dd76652 --- /dev/null +++ b/text_recognizer/datasets/sentence_generator.py @@ -0,0 +1,81 @@ +"""Downloading the Brown corpus with NLTK for sentence generating.""" + +import itertools +import re +import string +from typing import Optional + +import nltk +from nltk.corpus.reader.util import ConcatenatedCorpusView +import numpy as np + +from text_recognizer.datasets.util import DATA_DIRNAME + +NLTK_DATA_DIRNAME = DATA_DIRNAME / "raw" / "nltk" + + +class SentenceGenerator: + """Generates text sentences using the Brown corpus.""" + + def __init__(self, max_length: Optional[int] = None) -> None: + """Loads the corpus and sets word start indices.""" + self.corpus = brown_corpus() + self.word_start_indices = [0] + [ + _.start(0) + 1 for _ in re.finditer(" ", self.corpus) + ] + self.max_length = max_length + + def generate(self, max_length: Optional[int] = None) -> str: + """Generates a word or sentences from the Brown corpus. + + Sample a string from the Brown corpus of length at least one word and at most max_length, padding to + max_length with the '_' characters if sentence is shorter. + + Args: + max_length (Optional[int]): The maximum number of characters in the sentence. Defaults to None. + + Returns: + str: A sentence from the Brown corpus. + + Raises: + ValueError: If max_length was not specified at initialization and not given as an argument. + + """ + if max_length is None: + max_length = self.max_length + if max_length is None: + raise ValueError( + "Must provide max_length to this method or when making this object." + ) + + index = np.random.randint(0, len(self.word_start_indices) - 1) + start_index = self.word_start_indices[index] + end_index_candidates = [] + for index in range(index + 1, len(self.word_start_indices)): + if self.word_start_indices[index] - start_index > max_length: + break + end_index_candidates.append(self.word_start_indices[index]) + end_index = np.random.choice(end_index_candidates) + sampled_text = self.corpus[start_index:end_index].strip() + padding = "_" * (max_length - len(sampled_text)) + return sampled_text + padding + + +def brown_corpus() -> str: + """Returns a single string with the Brown corpus with all punctuations stripped.""" + sentences = load_nltk_brown_corpus() + corpus = " ".join(itertools.chain.from_iterable(sentences)) + corpus = corpus.translate({ord(c): None for c in string.punctuation}) + corpus = re.sub(" +", " ", corpus) + return corpus + + +def load_nltk_brown_corpus() -> ConcatenatedCorpusView: + """Load the Brown corpus using the NLTK library.""" + nltk.data.path.append(NLTK_DATA_DIRNAME) + try: + nltk.corpus.brown.sents() + except LookupError: + NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + nltk.download("brown", download_dir=NLTK_DATA_DIRNAME) + return nltk.corpus.brown.sents() diff --git a/text_recognizer/datasets/transforms.py b/text_recognizer/datasets/transforms.py new file mode 100644 index 0000000..b6a48f5 --- /dev/null +++ b/text_recognizer/datasets/transforms.py @@ -0,0 +1,266 @@ +"""Transforms for PyTorch datasets.""" +from abc import abstractmethod +from pathlib import Path +import random +from typing import Any, Optional, Union + +from loguru import logger +import numpy as np +from PIL import Image +import torch +from torch import Tensor +import torch.nn.functional as F +from torchvision import transforms +from torchvision.transforms import ( + ColorJitter, + Compose, + Normalize, + RandomAffine, + RandomHorizontalFlip, + RandomRotation, + ToPILImage, + ToTensor, +) + +from text_recognizer.datasets.iam_preprocessor import Preprocessor +from text_recognizer.datasets.util import EmnistMapper + + +class RandomResizeCrop: + """Image transform with random resize and crop applied. + + Stolen from + + https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py + + """ + + def __init__(self, jitter: int = 10, ratio: float = 0.5) -> None: + self.jitter = jitter + self.ratio = ratio + + def __call__(self, img: np.ndarray) -> np.ndarray: + """Applies random crop and rotation to an image.""" + w, h = img.size + + # pad with white: + img = transforms.functional.pad(img, self.jitter, fill=255) + + # crop at random (x, y): + x = self.jitter + random.randint(-self.jitter, self.jitter) + y = self.jitter + random.randint(-self.jitter, self.jitter) + + # randomize aspect ratio: + size_w = w * random.uniform(1 - self.ratio, 1 + self.ratio) + size = (h, int(size_w)) + img = transforms.functional.resized_crop(img, y, x, h, w, size) + return img + + +class Transpose: + """Transposes the EMNIST image to the correct orientation.""" + + def __call__(self, image: Image) -> np.ndarray: + """Swaps axis.""" + return np.array(image).swapaxes(0, 1) + + +class Resize: + """Resizes a tensor to a specified width.""" + + def __init__(self, width: int = 952) -> None: + # The default is 952 because of the IAM dataset. + self.width = width + + def __call__(self, image: Tensor) -> Tensor: + """Resize tensor in the last dimension.""" + return F.interpolate(image, size=self.width, mode="nearest") + + +class AddTokens: + """Adds start of sequence and end of sequence tokens to target tensor.""" + + def __init__(self, pad_token: str, eos_token: str, init_token: str = None) -> None: + self.init_token = init_token + self.pad_token = pad_token + self.eos_token = eos_token + if self.init_token is not None: + self.emnist_mapper = EmnistMapper( + init_token=self.init_token, + pad_token=self.pad_token, + eos_token=self.eos_token, + ) + else: + self.emnist_mapper = EmnistMapper( + pad_token=self.pad_token, eos_token=self.eos_token, + ) + self.pad_value = self.emnist_mapper(self.pad_token) + self.eos_value = self.emnist_mapper(self.eos_token) + + def __call__(self, target: Tensor) -> Tensor: + """Adds a sos token to the begining and a eos token to the end of a target sequence.""" + dtype, device = target.dtype, target.device + + # Find the where padding starts. + pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item() + + target[pad_index] = self.eos_value + + if self.init_token is not None: + self.sos_value = self.emnist_mapper(self.init_token) + sos = torch.tensor([self.sos_value], dtype=dtype, device=device) + target = torch.cat([sos, target], dim=0) + + return target + + +class ApplyContrast: + """Sets everything below a threshold to zero, i.e. increase contrast.""" + + def __init__(self, low: float = 0.0, high: float = 0.25) -> None: + self.low = low + self.high = high + + def __call__(self, x: Tensor) -> Tensor: + """Apply mask binary mask to input tensor.""" + mask = x > np.random.RandomState().uniform(low=self.low, high=self.high) + return x * mask + + +class Unsqueeze: + """Add a dimension to the tensor.""" + + def __call__(self, x: Tensor) -> Tensor: + """Adds dim.""" + return x.unsqueeze(0) + + +class Squeeze: + """Removes the first dimension of a tensor.""" + + def __call__(self, x: Tensor) -> Tensor: + """Removes first dim.""" + return x.squeeze(0) + + +class ToLower: + """Converts target to lower case.""" + + def __call__(self, target: Tensor) -> Tensor: + """Corrects index value in target tensor.""" + device = target.device + return torch.stack([x - 26 if x > 35 else x for x in target]).to(device) + + +class ToCharcters: + """Converts integers to characters.""" + + def __init__( + self, pad_token: str, eos_token: str, init_token: str = None, lower: bool = True + ) -> None: + self.init_token = init_token + self.pad_token = pad_token + self.eos_token = eos_token + if self.init_token is not None: + self.emnist_mapper = EmnistMapper( + init_token=self.init_token, + pad_token=self.pad_token, + eos_token=self.eos_token, + lower=lower, + ) + else: + self.emnist_mapper = EmnistMapper( + pad_token=self.pad_token, eos_token=self.eos_token, lower=lower + ) + + def __call__(self, y: Tensor) -> str: + """Converts a Tensor to a str.""" + return ( + "".join([self.emnist_mapper(int(i)) for i in y]) + .strip("_") + .replace(" ", "▁") + ) + + +class WordPieces: + """Abstract transform for word pieces.""" + + def __init__( + self, + num_features: int, + data_dir: Optional[Union[str, Path]] = None, + tokens: Optional[Union[str, Path]] = None, + lexicon: Optional[Union[str, Path]] = None, + use_words: bool = False, + prepend_wordsep: bool = False, + ) -> None: + if data_dir is None: + data_dir = ( + Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb" + ) + logger.debug(f"Using data dir: {data_dir}") + if not data_dir.exists(): + raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") + else: + data_dir = Path(data_dir) + processed_path = ( + Path(__file__).resolve().parents[3] / "data" / "processed" / "iam_lines" + ) + tokens_path = processed_path / tokens + lexicon_path = processed_path / lexicon + + self.preprocessor = Preprocessor( + data_dir, + num_features, + tokens_path, + lexicon_path, + use_words, + prepend_wordsep, + ) + + @abstractmethod + def __call__(self, *args, **kwargs) -> Any: + """Transforms input.""" + ... + + +class ToWordPieces(WordPieces): + """Transforms str to word pieces.""" + + def __init__( + self, + num_features: int, + data_dir: Optional[Union[str, Path]] = None, + tokens: Optional[Union[str, Path]] = None, + lexicon: Optional[Union[str, Path]] = None, + use_words: bool = False, + prepend_wordsep: bool = False, + ) -> None: + super().__init__( + num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep + ) + + def __call__(self, line: str) -> Tensor: + """Transforms str to word pieces.""" + return self.preprocessor.to_index(line) + + +class ToText(WordPieces): + """Takes word pieces and converts them to text.""" + + def __init__( + self, + num_features: int, + data_dir: Optional[Union[str, Path]] = None, + tokens: Optional[Union[str, Path]] = None, + lexicon: Optional[Union[str, Path]] = None, + use_words: bool = False, + prepend_wordsep: bool = False, + ) -> None: + super().__init__( + num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep + ) + + def __call__(self, x: Tensor) -> str: + """Converts tensor to text.""" + return self.preprocessor.to_text(x.tolist()) diff --git a/text_recognizer/datasets/util.py b/text_recognizer/datasets/util.py new file mode 100644 index 0000000..da87756 --- /dev/null +++ b/text_recognizer/datasets/util.py @@ -0,0 +1,209 @@ +"""Util functions for datasets.""" +import hashlib +import json +import os +from pathlib import Path +import string +from typing import Dict, List, Optional, Union +from urllib.request import urlretrieve + +from loguru import logger +import numpy as np +import torch +from torch import Tensor +from torchvision.datasets import EMNIST +from tqdm import tqdm + +DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data" +ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json" + + +def save_emnist_essentials(emnsit_dataset: EMNIST = EMNIST) -> None: + """Extract and saves EMNIST essentials.""" + labels = emnsit_dataset.classes + labels.sort() + mapping = [(i, str(label)) for i, label in enumerate(labels)] + essentials = { + "mapping": mapping, + "input_shape": tuple(np.array(emnsit_dataset[0][0]).shape[:]), + } + logger.info("Saving emnist essentials...") + with open(ESSENTIALS_FILENAME, "w") as f: + json.dump(essentials, f) + + +def download_emnist() -> None: + """Download the EMNIST dataset via the PyTorch class.""" + logger.info(f"Data directory is: {DATA_DIRNAME}") + dataset = EMNIST(root=DATA_DIRNAME, split="byclass", download=True) + save_emnist_essentials(dataset) + + +class EmnistMapper: + """Mapper between network output to Emnist character.""" + + def __init__( + self, + pad_token: str, + init_token: Optional[str] = None, + eos_token: Optional[str] = None, + lower: bool = False, + ) -> None: + """Loads the emnist essentials file with the mapping and input shape.""" + self.init_token = init_token + self.pad_token = pad_token + self.eos_token = eos_token + self.lower = lower + + self.essentials = self._load_emnist_essentials() + # Load dataset information. + self._mapping = dict(self.essentials["mapping"]) + self._augment_emnist_mapping() + self._inverse_mapping = {v: k for k, v in self.mapping.items()} + self._num_classes = len(self.mapping) + self._input_shape = self.essentials["input_shape"] + + def __call__(self, token: Union[str, int, np.uint8, Tensor]) -> Union[str, int]: + """Maps the token to emnist character or character index. + + If the token is an integer (index), the method will return the Emnist character corresponding to that index. + If the token is a str (Emnist character), the method will return the corresponding index for that character. + + Args: + token (Union[str, int, np.uint8, Tensor]): Either a string or index (integer). + + Returns: + Union[str, int]: The mapping result. + + Raises: + KeyError: If the index or string does not exist in the mapping. + + """ + if ( + (isinstance(token, np.uint8) or isinstance(token, int)) + or torch.is_tensor(token) + and int(token) in self.mapping + ): + return self.mapping[int(token)] + elif isinstance(token, str) and token in self._inverse_mapping: + return self._inverse_mapping[token] + else: + raise KeyError(f"Token {token} does not exist in the mappings.") + + @property + def mapping(self) -> Dict: + """Returns the mapping between index and character.""" + return self._mapping + + @property + def inverse_mapping(self) -> Dict: + """Returns the mapping between character and index.""" + return self._inverse_mapping + + @property + def num_classes(self) -> int: + """Returns the number of classes in the dataset.""" + return self._num_classes + + @property + def input_shape(self) -> List[int]: + """Returns the input shape of the Emnist characters.""" + return self._input_shape + + def _load_emnist_essentials(self) -> Dict: + """Load the EMNIST mapping.""" + with open(str(ESSENTIALS_FILENAME)) as f: + essentials = json.load(f) + return essentials + + def _augment_emnist_mapping(self) -> None: + """Augment the mapping with extra symbols.""" + # Extra symbols in IAM dataset + if self.lower: + self._mapping = { + k: str(v) + for k, v in enumerate(list(range(10)) + list(string.ascii_lowercase)) + } + + extra_symbols = [ + " ", + "!", + '"', + "#", + "&", + "'", + "(", + ")", + "*", + "+", + ",", + "-", + ".", + "/", + ":", + ";", + "?", + ] + + # padding symbol, and acts as blank symbol as well. + extra_symbols.append(self.pad_token) + + if self.init_token is not None: + extra_symbols.append(self.init_token) + + if self.eos_token is not None: + extra_symbols.append(self.eos_token) + + max_key = max(self.mapping.keys()) + extra_mapping = {} + for i, symbol in enumerate(extra_symbols): + extra_mapping[max_key + 1 + i] = symbol + + self._mapping = {**self.mapping, **extra_mapping} + + +def compute_sha256(filename: Union[Path, str]) -> str: + """Returns the SHA256 checksum of a file.""" + with open(filename, "rb") as f: + return hashlib.sha256(f.read()).hexdigest() + + +class TqdmUpTo(tqdm): + """TQDM progress bar when downloading files. + + From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py + + """ + + def update_to( + self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None + ) -> None: + """Updates the progress bar. + + Args: + blocks (int): Number of blocks transferred so far. Defaults to 1. + block_size (int): Size of each block, in tqdm units. Defaults to 1. + total_size (Optional[int]): Total size in tqdm units. Defaults to None. + """ + if total_size is not None: + self.total = total_size # pylint: disable=attribute-defined-outside-init + self.update(blocks * block_size - self.n) + + +def download_url(url: str, filename: str) -> None: + """Downloads a file from url to filename, with a progress bar.""" + with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: + urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec + + +def _download_raw_dataset(metadata: Dict) -> None: + if os.path.exists(metadata["filename"]): + return + logger.info(f"Downloading raw dataset from {metadata['url']}...") + download_url(metadata["url"], metadata["filename"]) + logger.info("Computing SHA-256...") + sha256 = compute_sha256(metadata["filename"]) + if sha256 != metadata["sha256"]: + raise ValueError( + "Downloaded data file SHA-256 does not match that listed in metadata document." + ) |