diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-20 18:09:06 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-20 18:09:06 +0100 |
commit | 7e8e54e84c63171e748bbf09516fd517e6821ace (patch) | |
tree | 996093f75a5d488dddf7ea1f159ed343a561ef89 /src/text_recognizer/datasets | |
parent | b0719d84138b6bbe5f04a4982dfca673aea1a368 (diff) |
Inital commit for refactoring to lightning
Diffstat (limited to 'src/text_recognizer/datasets')
-rw-r--r-- | src/text_recognizer/datasets/__init__.py | 39 | ||||
-rw-r--r-- | src/text_recognizer/datasets/dataset.py | 152 | ||||
-rw-r--r-- | src/text_recognizer/datasets/emnist_dataset.py | 131 | ||||
-rw-r--r-- | src/text_recognizer/datasets/emnist_essentials.json | 1 | ||||
-rw-r--r-- | src/text_recognizer/datasets/emnist_lines_dataset.py | 359 | ||||
-rw-r--r-- | src/text_recognizer/datasets/iam_dataset.py | 132 | ||||
-rw-r--r-- | src/text_recognizer/datasets/iam_lines_dataset.py | 110 | ||||
-rw-r--r-- | src/text_recognizer/datasets/iam_paragraphs_dataset.py | 291 | ||||
-rw-r--r-- | src/text_recognizer/datasets/iam_preprocessor.py | 196 | ||||
-rw-r--r-- | src/text_recognizer/datasets/sentence_generator.py | 81 | ||||
-rw-r--r-- | src/text_recognizer/datasets/transforms.py | 266 | ||||
-rw-r--r-- | src/text_recognizer/datasets/util.py | 209 |
12 files changed, 0 insertions, 1967 deletions
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py deleted file mode 100644 index a6c1c59..0000000 --- a/src/text_recognizer/datasets/__init__.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Dataset modules.""" -from .emnist_dataset import EmnistDataset -from .emnist_lines_dataset import ( - construct_image_from_string, - EmnistLinesDataset, - get_samples_by_character, -) -from .iam_dataset import IamDataset -from .iam_lines_dataset import IamLinesDataset -from .iam_paragraphs_dataset import IamParagraphsDataset -from .iam_preprocessor import load_metadata, Preprocessor -from .transforms import AddTokens, Transpose -from .util import ( - _download_raw_dataset, - compute_sha256, - DATA_DIRNAME, - download_url, - EmnistMapper, - ESSENTIALS_FILENAME, -) - -__all__ = [ - "_download_raw_dataset", - "AddTokens", - "compute_sha256", - "construct_image_from_string", - "DATA_DIRNAME", - "download_url", - "EmnistDataset", - "EmnistMapper", - "EmnistLinesDataset", - "get_samples_by_character", - "load_metadata", - "IamDataset", - "IamLinesDataset", - "IamParagraphsDataset", - "Preprocessor", - "Transpose", -] diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py deleted file mode 100644 index e794605..0000000 --- a/src/text_recognizer/datasets/dataset.py +++ /dev/null @@ -1,152 +0,0 @@ -"""Abstract dataset class.""" -from typing import Callable, Dict, List, Optional, Tuple, Union - -import torch -from torch import Tensor -from torch.utils import data -from torchvision.transforms import ToTensor - -import text_recognizer.datasets.transforms as transforms -from text_recognizer.datasets.util import EmnistMapper - - -class Dataset(data.Dataset): - """Abstract class for with common methods for all datasets.""" - - def __init__( - self, - train: bool, - subsample_fraction: float = None, - transform: Optional[List[Dict]] = None, - target_transform: Optional[List[Dict]] = None, - init_token: Optional[str] = None, - pad_token: Optional[str] = None, - eos_token: Optional[str] = None, - lower: bool = False, - ) -> None: - """Initialization of Dataset class. - - Args: - train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False. - subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None. - transform (Optional[List[Dict]]): List of Transform types and args for input data. Defaults to None. - target_transform (Optional[List[Dict]]): List of Transform types and args for output data. Defaults to None. - init_token (Optional[str]): String representing the start of sequence token. Defaults to None. - pad_token (Optional[str]): String representing the pad token. Defaults to None. - eos_token (Optional[str]): String representing the end of sequence token. Defaults to None. - lower (bool): Only use lower case letters. Defaults to False. - - Raises: - ValueError: If subsample_fraction is not None and outside the range (0, 1). - - """ - self.train = train - self.split = "train" if self.train else "test" - - if subsample_fraction is not None: - if not 0.0 < subsample_fraction < 1.0: - raise ValueError("The subsample fraction must be in (0, 1).") - self.subsample_fraction = subsample_fraction - - self._mapper = EmnistMapper( - init_token=init_token, eos_token=eos_token, pad_token=pad_token, lower=lower - ) - self._input_shape = self._mapper.input_shape - self._output_shape = self._mapper._num_classes - self.num_classes = self.mapper.num_classes - - # Set transforms. - self.transform = self._configure_transform(transform) - self.target_transform = self._configure_target_transform(target_transform) - - self._data = None - self._targets = None - - def _configure_transform(self, transform: List[Dict]) -> transforms.Compose: - transform_list = [] - if transform is not None: - for t in transform: - t_type = t["type"] - t_args = t["args"] or {} - transform_list.append(getattr(transforms, t_type)(**t_args)) - else: - transform_list.append(ToTensor()) - return transforms.Compose(transform_list) - - def _configure_target_transform( - self, target_transform: List[Dict] - ) -> transforms.Compose: - target_transform_list = [torch.tensor] - if target_transform is not None: - for t in target_transform: - t_type = t["type"] - t_args = t["args"] or {} - target_transform_list.append(getattr(transforms, t_type)(**t_args)) - return transforms.Compose(target_transform_list) - - @property - def data(self) -> Tensor: - """The input data.""" - return self._data - - @property - def targets(self) -> Tensor: - """The target data.""" - return self._targets - - @property - def input_shape(self) -> Tuple: - """Input shape of the data.""" - return self._input_shape - - @property - def output_shape(self) -> Tuple: - """Output shape of the data.""" - return self._output_shape - - @property - def mapper(self) -> EmnistMapper: - """Returns the EmnistMapper.""" - return self._mapper - - @property - def mapping(self) -> Dict: - """Return EMNIST mapping from index to character.""" - return self._mapper.mapping - - @property - def inverse_mapping(self) -> Dict: - """Returns the inverse mapping from character to index.""" - return self.mapper.inverse_mapping - - def _subsample(self) -> None: - """Only this fraction of the data will be loaded.""" - if self.subsample_fraction is None: - return - num_subsample = int(self.data.shape[0] * self.subsample_fraction) - self._data = self.data[:num_subsample] - self._targets = self.targets[:num_subsample] - - def __len__(self) -> int: - """Returns the length of the dataset.""" - return len(self.data) - - def load_or_generate_data(self) -> None: - """Load or generate dataset data.""" - raise NotImplementedError - - def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: - """Fetches samples from the dataset. - - Args: - index (Union[int, torch.Tensor]): The indices of the samples to fetch. - - Raises: - NotImplementedError: If the method is not implemented in child class. - - """ - raise NotImplementedError - - def __repr__(self) -> str: - """Returns information about the dataset.""" - raise NotImplementedError diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py deleted file mode 100644 index 9884fdf..0000000 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ /dev/null @@ -1,131 +0,0 @@ -"""Emnist dataset: black and white images of handwritten characters (Aa-Zz) and digits (0-9).""" - -import json -from pathlib import Path -from typing import Callable, Optional, Tuple, Union - -from loguru import logger -import numpy as np -from PIL import Image -import torch -from torch import Tensor -from torchvision.datasets import EMNIST -from torchvision.transforms import Compose, ToTensor - -from text_recognizer.datasets.dataset import Dataset -from text_recognizer.datasets.transforms import Transpose -from text_recognizer.datasets.util import DATA_DIRNAME - - -class EmnistDataset(Dataset): - """This is a class for resampling and subsampling the PyTorch EMNIST dataset.""" - - def __init__( - self, - pad_token: str = None, - train: bool = False, - sample_to_balance: bool = False, - subsample_fraction: float = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - seed: int = 4711, - ) -> None: - """Loads the dataset and the mappings. - - Args: - pad_token (str): The pad token symbol. Defaults to _. - train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False. - sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False. - subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None. - transform (Optional[Callable]): Transform(s) for input data. Defaults to None. - target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None. - seed (int): Seed number. Defaults to 4711. - - """ - super().__init__( - train=train, - subsample_fraction=subsample_fraction, - transform=transform, - target_transform=target_transform, - pad_token=pad_token, - ) - - self.sample_to_balance = sample_to_balance - - # Have to transpose the emnist characters, ToTensor norms input between [0,1]. - if transform is None: - self.transform = Compose([Transpose(), ToTensor()]) - - self.target_transform = None - - self.seed = seed - - def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: - """Fetches samples from the dataset. - - Args: - index (Union[int, Tensor]): The indices of the samples to fetch. - - Returns: - Tuple[Tensor, Tensor]: Data target tuple. - - """ - if torch.is_tensor(index): - index = index.tolist() - - data = self.data[index] - targets = self.targets[index] - - if self.transform: - data = self.transform(data) - - if self.target_transform: - targets = self.target_transform(targets) - - return data, targets - - def __repr__(self) -> str: - """Returns information about the dataset.""" - return ( - "EMNIST Dataset\n" - f"Num classes: {self.num_classes}\n" - f"Input shape: {self.input_shape}\n" - f"Mapping: {self.mapper.mapping}\n" - ) - - def _sample_to_balance(self) -> None: - """Because the dataset is not balanced, we take at most the mean number of instances per class.""" - np.random.seed(self.seed) - x = self._data - y = self._targets - num_to_sample = int(np.bincount(y.flatten()).mean()) - all_sampled_indices = [] - for label in np.unique(y.flatten()): - inds = np.where(y == label)[0] - sampled_indices = np.unique(np.random.choice(inds, num_to_sample)) - all_sampled_indices.append(sampled_indices) - indices = np.concatenate(all_sampled_indices) - x_sampled = x[indices] - y_sampled = y[indices] - self._data = x_sampled - self._targets = y_sampled - - def load_or_generate_data(self) -> None: - """Fetch the EMNIST dataset.""" - dataset = EMNIST( - root=DATA_DIRNAME, - split="byclass", - train=self.train, - download=False, - transform=None, - target_transform=None, - ) - - self._data = dataset.data - self._targets = dataset.targets - - if self.sample_to_balance: - self._sample_to_balance() - - if self.subsample_fraction is not None: - self._subsample() diff --git a/src/text_recognizer/datasets/emnist_essentials.json b/src/text_recognizer/datasets/emnist_essentials.json deleted file mode 100644 index 2a0648a..0000000 --- a/src/text_recognizer/datasets/emnist_essentials.json +++ /dev/null @@ -1 +0,0 @@ -{"mapping": [[0, "0"], [1, "1"], [2, "2"], [3, "3"], [4, "4"], [5, "5"], [6, "6"], [7, "7"], [8, "8"], [9, "9"], [10, "A"], [11, "B"], [12, "C"], [13, "D"], [14, "E"], [15, "F"], [16, "G"], [17, "H"], [18, "I"], [19, "J"], [20, "K"], [21, "L"], [22, "M"], [23, "N"], [24, "O"], [25, "P"], [26, "Q"], [27, "R"], [28, "S"], [29, "T"], [30, "U"], [31, "V"], [32, "W"], [33, "X"], [34, "Y"], [35, "Z"], [36, "a"], [37, "b"], [38, "c"], [39, "d"], [40, "e"], [41, "f"], [42, "g"], [43, "h"], [44, "i"], [45, "j"], [46, "k"], [47, "l"], [48, "m"], [49, "n"], [50, "o"], [51, "p"], [52, "q"], [53, "r"], [54, "s"], [55, "t"], [56, "u"], [57, "v"], [58, "w"], [59, "x"], [60, "y"], [61, "z"]], "input_shape": [28, 28]} diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py deleted file mode 100644 index 1992446..0000000 --- a/src/text_recognizer/datasets/emnist_lines_dataset.py +++ /dev/null @@ -1,359 +0,0 @@ -"""Emnist Lines dataset: synthetic handwritten lines dataset made from Emnist characters.""" - -from collections import defaultdict -from pathlib import Path -from typing import Callable, Dict, List, Optional, Tuple, Union - -import click -import h5py -from loguru import logger -import numpy as np -import torch -from torch import Tensor -import torch.nn.functional as F -from torchvision.transforms import ToTensor - -from text_recognizer.datasets.dataset import Dataset -from text_recognizer.datasets.emnist_dataset import EmnistDataset, Transpose -from text_recognizer.datasets.sentence_generator import SentenceGenerator -from text_recognizer.datasets.util import ( - DATA_DIRNAME, - EmnistMapper, - ESSENTIALS_FILENAME, -) - -DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines" - -MAX_WIDTH = 952 - - -class EmnistLinesDataset(Dataset): - """Synthetic dataset of lines from the Brown corpus with Emnist characters.""" - - def __init__( - self, - train: bool = False, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - subsample_fraction: float = None, - max_length: int = 34, - min_overlap: float = 0, - max_overlap: float = 0.33, - num_samples: int = 10000, - seed: int = 4711, - init_token: Optional[str] = None, - pad_token: Optional[str] = None, - eos_token: Optional[str] = None, - lower: bool = False, - ) -> None: - """Set attributes and loads the dataset. - - Args: - train (bool): Flag for the filename. Defaults to False. Defaults to None. - transform (Optional[Callable]): The transform of the data. Defaults to None. - target_transform (Optional[Callable]): The transform of the target. Defaults to None. - subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None. - max_length (int): The maximum number of characters. Defaults to 34. - min_overlap (float): The minimum overlap between concatenated images. Defaults to 0. - max_overlap (float): The maximum overlap between concatenated images. Defaults to 0.33. - num_samples (int): Number of samples to generate. Defaults to 10000. - seed (int): Seed number. Defaults to 4711. - init_token (Optional[str]): String representing the start of sequence token. Defaults to None. - pad_token (Optional[str]): String representing the pad token. Defaults to None. - eos_token (Optional[str]): String representing the end of sequence token. Defaults to None. - lower (bool): If True, convert uppercase letters to lowercase. Otherwise, use both upper and lowercase. - - """ - self.pad_token = "_" if pad_token is None else pad_token - - super().__init__( - train=train, - transform=transform, - target_transform=target_transform, - subsample_fraction=subsample_fraction, - init_token=init_token, - pad_token=self.pad_token, - eos_token=eos_token, - lower=lower, - ) - - # Extract dataset information. - self._input_shape = self._mapper.input_shape - self.num_classes = self._mapper.num_classes - - self.max_length = max_length - self.min_overlap = min_overlap - self.max_overlap = max_overlap - self.num_samples = num_samples - self._input_shape = ( - self.input_shape[0], - self.input_shape[1] * self.max_length, - ) - self._output_shape = (self.max_length, self.num_classes) - self.seed = seed - - # Placeholders for the dataset. - self._data = None - self._target = None - - def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: - """Fetches data, target pair of the dataset for a given and index or indices. - - Args: - index (Union[int, Tensor]): Either a list or int of indices/index. - - Returns: - Tuple[Tensor, Tensor]: Data target pair. - - """ - if torch.is_tensor(index): - index = index.tolist() - - data = self.data[index] - targets = self.targets[index] - - if self.transform: - data = self.transform(data) - - if self.target_transform: - targets = self.target_transform(targets) - - return data, targets - - def __repr__(self) -> str: - """Returns information about the dataset.""" - return ( - "EMNIST Lines Dataset\n" # pylint: disable=no-member - f"Max length: {self.max_length}\n" - f"Min overlap: {self.min_overlap}\n" - f"Max overlap: {self.max_overlap}\n" - f"Num classes: {self.num_classes}\n" - f"Input shape: {self.input_shape}\n" - f"Data: {self.data.shape}\n" - f"Tagets: {self.targets.shape}\n" - ) - - @property - def data_filename(self) -> Path: - """Path to the h5 file.""" - filename = "train.pt" if self.train else "test.pt" - return DATA_DIRNAME / filename - - def load_or_generate_data(self) -> None: - """Loads the dataset, if it does not exist a new dataset is generated before loading it.""" - np.random.seed(self.seed) - - if not self.data_filename.exists(): - self._generate_data() - self._load_data() - self._subsample() - - def _load_data(self) -> None: - """Loads the dataset from the h5 file.""" - logger.debug("EmnistLinesDataset loading data from HDF5...") - with h5py.File(self.data_filename, "r") as f: - self._data = f["data"][()] - self._targets = f["targets"][()] - - def _generate_data(self) -> str: - """Generates a dataset with the Brown corpus and Emnist characters.""" - logger.debug("Generating data...") - - sentence_generator = SentenceGenerator(self.max_length) - - # Load emnist dataset. - emnist = EmnistDataset( - train=self.train, sample_to_balance=True, pad_token=self.pad_token - ) - emnist.load_or_generate_data() - - samples_by_character = get_samples_by_character( - emnist.data.numpy(), emnist.targets.numpy(), self.mapper.mapping, - ) - - DATA_DIRNAME.mkdir(parents=True, exist_ok=True) - with h5py.File(self.data_filename, "a") as f: - data, targets = create_dataset_of_images( - self.num_samples, - samples_by_character, - sentence_generator, - self.min_overlap, - self.max_overlap, - ) - - targets = convert_strings_to_categorical_labels( - targets, emnist.inverse_mapping - ) - - f.create_dataset("data", data=data, dtype="u1", compression="lzf") - f.create_dataset("targets", data=targets, dtype="u1", compression="lzf") - - -def get_samples_by_character( - samples: np.ndarray, labels: np.ndarray, mapping: Dict -) -> defaultdict: - """Creates a dictionary with character as key and value as the list of images of that character. - - Args: - samples (np.ndarray): Dataset of images of characters. - labels (np.ndarray): The labels for each image. - mapping (Dict): The Emnist mapping dictionary. - - Returns: - defaultdict: A dictionary with characters as keys and list of images as values. - - """ - samples_by_character = defaultdict(list) - for sample, label in zip(samples, labels.flatten()): - samples_by_character[mapping[label]].append(sample) - return samples_by_character - - -def select_letter_samples_for_string( - string: str, samples_by_character: Dict -) -> List[np.ndarray]: - """Randomly selects Emnist characters to use for the senetence. - - Args: - string (str): The word or sentence. - samples_by_character (Dict): The dictionary of emnist images of each character. - - Returns: - List[np.ndarray]: A list of emnist images of the string. - - """ - zero_image = np.zeros((28, 28), np.uint8) - sample_image_by_character = {} - for character in string: - if character in sample_image_by_character: - continue - samples = samples_by_character[character] - sample = samples[np.random.choice(len(samples))] if samples else zero_image - sample_image_by_character[character] = sample.reshape(28, 28).swapaxes(0, 1) - return [sample_image_by_character[character] for character in string] - - -def construct_image_from_string( - string: str, samples_by_character: Dict, min_overlap: float, max_overlap: float -) -> np.ndarray: - """Concatenates images of the characters in the string. - - The concatination is made with randomly selected overlap so that some portion of the character will overlap. - - Args: - string (str): The word or sentence. - samples_by_character (Dict): The dictionary of emnist images of each character. - min_overlap (float): Minimum amount of overlap between Emnist images. - max_overlap (float): Maximum amount of overlap between Emnist images. - - Returns: - np.ndarray: The Emnist image of the string. - - """ - overlap = np.random.uniform(min_overlap, max_overlap) - sampled_images = select_letter_samples_for_string(string, samples_by_character) - length = len(sampled_images) - height, width = sampled_images[0].shape - next_overlap_width = width - int(overlap * width) - concatenated_image = np.zeros((height, width * length), np.uint8) - x = 0 - for image in sampled_images: - concatenated_image[:, x : (x + width)] += image - x += next_overlap_width - - if concatenated_image.shape[-1] > MAX_WIDTH: - concatenated_image = Tensor(concatenated_image).unsqueeze(0) - concatenated_image = F.interpolate( - concatenated_image, size=MAX_WIDTH, mode="nearest" - ) - concatenated_image = concatenated_image.squeeze(0).numpy() - - return np.minimum(255, concatenated_image) - - -def create_dataset_of_images( - length: int, - samples_by_character: Dict, - sentence_generator: SentenceGenerator, - min_overlap: float, - max_overlap: float, -) -> Tuple[np.ndarray, List[str]]: - """Creates a dataset with images and labels from strings generated from the SentenceGenerator. - - Args: - length (int): The number of characters for each string. - samples_by_character (Dict): The dictionary of emnist images of each character. - sentence_generator (SentenceGenerator): A SentenceGenerator objest. - min_overlap (float): Minimum amount of overlap between Emnist images. - max_overlap (float): Maximum amount of overlap between Emnist images. - - Returns: - Tuple[np.ndarray, List[str]]: A list of Emnist images and a list of the strings (labels). - - Raises: - RuntimeError: If the sentence generator is not able to generate a string. - - """ - sample_label = sentence_generator.generate() - sample_image = construct_image_from_string(sample_label, samples_by_character, 0, 0) - images = np.zeros((length, sample_image.shape[0], sample_image.shape[1]), np.uint8) - labels = [] - for n in range(length): - label = None - # Try several times to generate before actually throwing an error. - for _ in range(10): - try: - label = sentence_generator.generate() - break - except Exception: # pylint: disable=broad-except - pass - if label is None: - raise RuntimeError("Was not able to generate a valid string.") - images[n] = construct_image_from_string( - label, samples_by_character, min_overlap, max_overlap - ) - labels.append(label) - return images, labels - - -def convert_strings_to_categorical_labels( - labels: List[str], mapping: Dict -) -> np.ndarray: - """Translates a string of characters in to a target array of class int.""" - return np.array([[mapping[c] for c in label] for label in labels]) - - -@click.command() -@click.option( - "--max_length", type=int, default=34, help="Number of characters in a sentence." -) -@click.option( - "--min_overlap", type=float, default=0.0, help="Min overlap between characters." -) -@click.option( - "--max_overlap", type=float, default=0.33, help="Max overlap between characters." -) -@click.option("--num_train", type=int, default=10_000, help="Number of train examples.") -@click.option("--num_test", type=int, default=1_000, help="Number of test examples.") -def create_datasets( - max_length: int = 34, - min_overlap: float = 0, - max_overlap: float = 0.33, - num_train: int = 10000, - num_test: int = 1000, -) -> None: - """Creates a training an validation dataset of Emnist lines.""" - num_samples = [num_train, num_test] - for num, train in zip(num_samples, [True, False]): - emnist_lines = EmnistLinesDataset( - train=train, - max_length=max_length, - min_overlap=min_overlap, - max_overlap=max_overlap, - num_samples=num, - ) - emnist_lines.load_or_generate_data() - - -if __name__ == "__main__": - create_datasets() diff --git a/src/text_recognizer/datasets/iam_dataset.py b/src/text_recognizer/datasets/iam_dataset.py deleted file mode 100644 index f4a869d..0000000 --- a/src/text_recognizer/datasets/iam_dataset.py +++ /dev/null @@ -1,132 +0,0 @@ -"""Class for loading the IAM dataset, which encompasses both paragraphs and lines, with associated utilities.""" -import os -from typing import Any, Dict, List -import zipfile - -from boltons.cacheutils import cachedproperty -import defusedxml.ElementTree as ET -from loguru import logger -import toml - -from text_recognizer.datasets.util import _download_raw_dataset, DATA_DIRNAME - -RAW_DATA_DIRNAME = DATA_DIRNAME / "raw" / "iam" -METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" -EXTRACTED_DATASET_DIRNAME = RAW_DATA_DIRNAME / "iamdb" - -DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be. -LINE_REGION_PADDING = 0 # Add this many pixels around the exact coordinates. - - -class IamDataset: - """IAM dataset. - - "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text, - which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels." - From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database - - The data split we will use is - IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text lines. - The validation set has been merged into the train set. - The train set has 7,101 lines from 326 writers. - The test set has 1,861 lines from 128 writers. - The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only. - - """ - - def __init__(self) -> None: - self.metadata = toml.load(METADATA_FILENAME) - - def load_or_generate_data(self) -> None: - """Downloads IAM dataset if xml files does not exist.""" - if not self.xml_filenames: - self._download_iam() - - @property - def xml_filenames(self) -> List: - """List of xml filenames.""" - return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml")) - - @property - def form_filenames(self) -> List: - """List of forms filenames.""" - return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg")) - - def _download_iam(self) -> None: - curdir = os.getcwd() - os.chdir(RAW_DATA_DIRNAME) - _download_raw_dataset(self.metadata) - _extract_raw_dataset(self.metadata) - os.chdir(curdir) - - @property - def form_filenames_by_id(self) -> Dict: - """Creates a dictionary with filenames as keys and forms as values.""" - return {filename.stem: filename for filename in self.form_filenames} - - @cachedproperty - def line_strings_by_id(self) -> Dict: - """Return a dict from name of IAM form to a list of line texts in it.""" - return { - filename.stem: _get_line_strings_from_xml_file(filename) - for filename in self.xml_filenames - } - - @cachedproperty - def line_regions_by_id(self) -> Dict: - """Return a dict from name of IAM form to a list of (x1, x2, y1, y2) coordinates of all lines in it.""" - return { - filename.stem: _get_line_regions_from_xml_file(filename) - for filename in self.xml_filenames - } - - def __repr__(self) -> str: - """Print info about dataset.""" - return "IAM Dataset\n" f"Number of forms: {len(self.xml_filenames)}\n" - - -def _extract_raw_dataset(metadata: Dict) -> None: - logger.info("Extracting IAM data.") - with zipfile.ZipFile(metadata["filename"], "r") as zip_file: - zip_file.extractall() - - -def _get_line_strings_from_xml_file(filename: str) -> List[str]: - """Get the text content of each line. Note that we replace " with ".""" - xml_root_element = ET.parse(filename).getroot() # nosec - xml_line_elements = xml_root_element.findall("handwritten-part/line") - return [el.attrib["text"].replace(""", '"') for el in xml_line_elements] - - -def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]: - """Get the line region dict for each line.""" - xml_root_element = ET.parse(filename).getroot() # nosec - xml_line_elements = xml_root_element.findall("handwritten-part/line") - return [_get_line_region_from_xml_element(el) for el in xml_line_elements] - - -def _get_line_region_from_xml_element(xml_line: Any) -> Dict[str, int]: - """Extracts coordinates for each line of text.""" - # TODO: fix input! - word_elements = xml_line.findall("word/cmp") - x1s = [int(el.attrib["x"]) for el in word_elements] - y1s = [int(el.attrib["y"]) for el in word_elements] - x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements] - y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements] - return { - "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, - "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, - "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, - "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, - } - - -def main() -> None: - """Initializes the dataset and print info about the dataset.""" - dataset = IamDataset() - dataset.load_or_generate_data() - print(dataset) - - -if __name__ == "__main__": - main() diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/src/text_recognizer/datasets/iam_lines_dataset.py deleted file mode 100644 index 1cb84bd..0000000 --- a/src/text_recognizer/datasets/iam_lines_dataset.py +++ /dev/null @@ -1,110 +0,0 @@ -"""IamLinesDataset class.""" -from typing import Callable, Dict, List, Optional, Tuple, Union - -import h5py -from loguru import logger -import torch -from torch import Tensor -from torchvision.transforms import ToTensor - -from text_recognizer.datasets.dataset import Dataset -from text_recognizer.datasets.util import ( - compute_sha256, - DATA_DIRNAME, - download_url, - EmnistMapper, -) - - -PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines" -PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "iam_lines.h5" -PROCESSED_DATA_URL = ( - "https://s3-us-west-2.amazonaws.com/fsdl-public-assets/iam_lines.h5" -) - - -class IamLinesDataset(Dataset): - """IAM lines datasets for handwritten text lines.""" - - def __init__( - self, - train: bool = False, - subsample_fraction: float = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - init_token: Optional[str] = None, - pad_token: Optional[str] = None, - eos_token: Optional[str] = None, - lower: bool = False, - ) -> None: - self.pad_token = "_" if pad_token is None else pad_token - - super().__init__( - train=train, - subsample_fraction=subsample_fraction, - transform=transform, - target_transform=target_transform, - init_token=init_token, - pad_token=pad_token, - eos_token=eos_token, - lower=lower, - ) - - @property - def input_shape(self) -> Tuple: - """Input shape of the data.""" - return self.data.shape[1:] if self.data is not None else None - - @property - def output_shape(self) -> Tuple: - """Output shape of the data.""" - return ( - self.targets.shape[1:] + (self.num_classes,) - if self.targets is not None - else None - ) - - def load_or_generate_data(self) -> None: - """Load or generate dataset data.""" - if not PROCESSED_DATA_FILENAME.exists(): - PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) - logger.info("Downloading IAM lines...") - download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME) - with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: - self._data = f[f"x_{self.split}"][:] - self._targets = f[f"y_{self.split}"][:] - self._subsample() - - def __repr__(self) -> str: - """Print info about the dataset.""" - return ( - "IAM Lines Dataset\n" # pylint: disable=no-member - f"Number classes: {self.num_classes}\n" - f"Mapping: {self.mapper.mapping}\n" - f"Data: {self.data.shape}\n" - f"Targets: {self.targets.shape}\n" - ) - - def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]: - """Fetches data, target pair of the dataset for a given and index or indices. - - Args: - index (Union[int, Tensor]): Either a list or int of indices/index. - - Returns: - Tuple[Tensor, Tensor]: Data target pair. - - """ - if torch.is_tensor(index): - index = index.tolist() - - data = self.data[index] - targets = self.targets[index] - - if self.transform: - data = self.transform(data) - - if self.target_transform: - targets = self.target_transform(targets) - - return data, targets diff --git a/src/text_recognizer/datasets/iam_paragraphs_dataset.py b/src/text_recognizer/datasets/iam_paragraphs_dataset.py deleted file mode 100644 index 8ba5142..0000000 --- a/src/text_recognizer/datasets/iam_paragraphs_dataset.py +++ /dev/null @@ -1,291 +0,0 @@ -"""IamParagraphsDataset class and functions for data processing.""" -import random -from typing import Callable, Dict, List, Optional, Tuple, Union - -import click -import cv2 -import h5py -from loguru import logger -import numpy as np -import torch -from torch import Tensor -from torchvision.transforms import ToTensor - -from text_recognizer import util -from text_recognizer.datasets.dataset import Dataset -from text_recognizer.datasets.iam_dataset import IamDataset -from text_recognizer.datasets.util import ( - compute_sha256, - DATA_DIRNAME, - download_url, - EmnistMapper, -) - -INTERIM_DATA_DIRNAME = DATA_DIRNAME / "interim" / "iam_paragraphs" -DEBUG_CROPS_DIRNAME = INTERIM_DATA_DIRNAME / "debug_crops" -PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_paragraphs" -CROPS_DIRNAME = PROCESSED_DATA_DIRNAME / "crops" -GT_DIRNAME = PROCESSED_DATA_DIRNAME / "gt" - -PARAGRAPH_BUFFER = 50 # Pixels in the IAM form images to leave around the lines. -TEST_FRACTION = 0.2 -SEED = 4711 - - -class IamParagraphsDataset(Dataset): - """IAM Paragraphs dataset for paragraphs of handwritten text.""" - - def __init__( - self, - train: bool = False, - subsample_fraction: float = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - ) -> None: - super().__init__( - train=train, - subsample_fraction=subsample_fraction, - transform=transform, - target_transform=target_transform, - ) - # Load Iam dataset. - self.iam_dataset = IamDataset() - - self.num_classes = 3 - self._input_shape = (256, 256) - self._output_shape = self._input_shape + (self.num_classes,) - self._ids = None - - def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]: - """Fetches data, target pair of the dataset for a given and index or indices. - - Args: - index (Union[int, Tensor]): Either a list or int of indices/index. - - Returns: - Tuple[Tensor, Tensor]: Data target pair. - - """ - if torch.is_tensor(index): - index = index.tolist() - - data = self.data[index] - targets = self.targets[index] - - seed = np.random.randint(SEED) - random.seed(seed) # apply this seed to target tranfsorms - torch.manual_seed(seed) # needed for torchvision 0.7 - if self.transform: - data = self.transform(data) - - random.seed(seed) # apply this seed to target tranfsorms - torch.manual_seed(seed) # needed for torchvision 0.7 - if self.target_transform: - targets = self.target_transform(targets) - - return data, targets.long() - - @property - def ids(self) -> Tensor: - """Ids of the dataset.""" - return self._ids - - def get_data_and_target_from_id(self, id_: str) -> Tuple[Tensor, Tensor]: - """Get data target pair from id.""" - ind = self.ids.index(id_) - return self.data[ind], self.targets[ind] - - def load_or_generate_data(self) -> None: - """Load or generate dataset data.""" - num_actual = len(list(CROPS_DIRNAME.glob("*.jpg"))) - num_targets = len(self.iam_dataset.line_regions_by_id) - - if num_actual < num_targets - 2: - self._process_iam_paragraphs() - - self._data, self._targets, self._ids = _load_iam_paragraphs() - self._get_random_split() - self._subsample() - - def _get_random_split(self) -> None: - np.random.seed(SEED) - num_train = int((1 - TEST_FRACTION) * self.data.shape[0]) - indices = np.random.permutation(self.data.shape[0]) - train_indices, test_indices = indices[:num_train], indices[num_train:] - if self.train: - self._data = self.data[train_indices] - self._targets = self.targets[train_indices] - else: - self._data = self.data[test_indices] - self._targets = self.targets[test_indices] - - def _process_iam_paragraphs(self) -> None: - """Crop the part with the text. - - For each page, crop out the part of it that correspond to the paragraph of text, and make sure all crops are - self.input_shape. The ground truth data is the same size, with a one-hot vector at each pixel - corresponding to labels 0=background, 1=odd-numbered line, 2=even-numbered line - """ - crop_dims = self._decide_on_crop_dims() - CROPS_DIRNAME.mkdir(parents=True, exist_ok=True) - DEBUG_CROPS_DIRNAME.mkdir(parents=True, exist_ok=True) - GT_DIRNAME.mkdir(parents=True, exist_ok=True) - logger.info( - f"Cropping paragraphs, generating ground truth, and saving debugging images to {DEBUG_CROPS_DIRNAME}" - ) - for filename in self.iam_dataset.form_filenames: - id_ = filename.stem - line_region = self.iam_dataset.line_regions_by_id[id_] - _crop_paragraph_image(filename, line_region, crop_dims, self.input_shape) - - def _decide_on_crop_dims(self) -> Tuple[int, int]: - """Decide on the dimensions to crop out of the form image. - - Since image width is larger than a comfortable crop around the longest paragraph, - we will make the crop a square form factor. - And since the found dimensions 610x610 are pretty close to 512x512, - we might as well resize crops and make it exactly that, which lets us - do all kinds of power-of-2 pooling and upsampling should we choose to. - - Returns: - Tuple[int, int]: A tuple of crop dimensions. - - Raises: - RuntimeError: When max crop height is larger than max crop width. - - """ - - sample_form_filename = self.iam_dataset.form_filenames[0] - sample_image = util.read_image(sample_form_filename, grayscale=True) - max_crop_width = sample_image.shape[1] - max_crop_height = _get_max_paragraph_crop_height( - self.iam_dataset.line_regions_by_id - ) - if not max_crop_height <= max_crop_width: - raise RuntimeError( - f"Max crop height is larger then max crop width: {max_crop_height} >= {max_crop_width}" - ) - - crop_dims = (max_crop_width, max_crop_width) - logger.info( - f"Max crop width and height were found to be {max_crop_width}x{max_crop_height}." - ) - logger.info(f"Setting them to {max_crop_width}x{max_crop_width}") - return crop_dims - - def __repr__(self) -> str: - """Return info about the dataset.""" - return ( - "IAM Paragraph Dataset\n" # pylint: disable=no-member - f"Num classes: {self.num_classes}\n" - f"Data: {self.data.shape}\n" - f"Targets: {self.targets.shape}\n" - ) - - -def _get_max_paragraph_crop_height(line_regions_by_id: Dict) -> int: - heights = [] - for regions in line_regions_by_id.values(): - min_y1 = min(region["y1"] for region in regions) - PARAGRAPH_BUFFER - max_y2 = max(region["y2"] for region in regions) + PARAGRAPH_BUFFER - height = max_y2 - min_y1 - heights.append(height) - return max(heights) - - -def _crop_paragraph_image( - filename: str, line_regions: Dict, crop_dims: Tuple[int, int], final_dims: Tuple -) -> None: - image = util.read_image(filename, grayscale=True) - - min_y1 = min(region["y1"] for region in line_regions) - PARAGRAPH_BUFFER - max_y2 = max(region["y2"] for region in line_regions) + PARAGRAPH_BUFFER - height = max_y2 - min_y1 - crop_height = crop_dims[0] - buffer = (crop_height - height) // 2 - - # Generate image crop. - image_crop = 255 * np.ones(crop_dims, dtype=np.uint8) - try: - image_crop[buffer : buffer + height] = image[min_y1:max_y2] - except Exception as e: # pylint: disable=broad-except - logger.error(f"Rescued {filename}: {e}") - return - - # Generate ground truth. - gt_image = np.zeros_like(image_crop, dtype=np.uint8) - for index, region in enumerate(line_regions): - gt_image[ - (region["y1"] - min_y1 + buffer) : (region["y2"] - min_y1 + buffer), - region["x1"] : region["x2"], - ] = (index % 2 + 1) - - # Generate image for debugging. - import matplotlib.pyplot as plt - - cmap = plt.get_cmap("Set1") - image_crop_for_debug = np.dstack([image_crop, image_crop, image_crop]) - for index, region in enumerate(line_regions): - color = [255 * _ for _ in cmap(index)[:-1]] - cv2.rectangle( - image_crop_for_debug, - (region["x1"], region["y1"] - min_y1 + buffer), - (region["x2"], region["y2"] - min_y1 + buffer), - color, - 3, - ) - image_crop_for_debug = cv2.resize( - image_crop_for_debug, final_dims, interpolation=cv2.INTER_AREA - ) - util.write_image(image_crop_for_debug, DEBUG_CROPS_DIRNAME / f"{filename.stem}.jpg") - - image_crop = cv2.resize(image_crop, final_dims, interpolation=cv2.INTER_AREA) - util.write_image(image_crop, CROPS_DIRNAME / f"{filename.stem}.jpg") - - gt_image = cv2.resize(gt_image, final_dims, interpolation=cv2.INTER_NEAREST) - util.write_image(gt_image, GT_DIRNAME / f"{filename.stem}.png") - - -def _load_iam_paragraphs() -> None: - logger.info("Loading IAM paragraph crops and ground truth from image files...") - images = [] - gt_images = [] - ids = [] - for filename in CROPS_DIRNAME.glob("*.jpg"): - id_ = filename.stem - image = util.read_image(filename, grayscale=True) - image = 1.0 - image / 255 - - gt_filename = GT_DIRNAME / f"{id_}.png" - gt_image = util.read_image(gt_filename, grayscale=True) - - images.append(image) - gt_images.append(gt_image) - ids.append(id_) - images = np.array(images).astype(np.float32) - gt_images = np.array(gt_images).astype(np.uint8) - ids = np.array(ids) - return images, gt_images, ids - - -@click.command() -@click.option( - "--subsample_fraction", - type=float, - default=None, - help="The subsampling factor of the dataset.", -) -def main(subsample_fraction: float) -> None: - """Load dataset and print info.""" - logger.info("Creating train set...") - dataset = IamParagraphsDataset(train=True, subsample_fraction=subsample_fraction) - dataset.load_or_generate_data() - print(dataset) - logger.info("Creating test set...") - dataset = IamParagraphsDataset(subsample_fraction=subsample_fraction) - dataset.load_or_generate_data() - print(dataset) - - -if __name__ == "__main__": - main() diff --git a/src/text_recognizer/datasets/iam_preprocessor.py b/src/text_recognizer/datasets/iam_preprocessor.py deleted file mode 100644 index a93eb00..0000000 --- a/src/text_recognizer/datasets/iam_preprocessor.py +++ /dev/null @@ -1,196 +0,0 @@ -"""Preprocessor for extracting word letters from the IAM dataset. - -The code is mostly stolen from: - - https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py - -""" - -import collections -import itertools -from pathlib import Path -import re -from typing import List, Optional, Union - -import click -from loguru import logger -import torch - - -def load_metadata( - data_dir: Path, wordsep: str, use_words: bool = False -) -> collections.defaultdict: - """Loads IAM metadata and returns it as a dictionary.""" - forms = collections.defaultdict(list) - filename = "words.txt" if use_words else "lines.txt" - - with open(data_dir / "ascii" / filename, "r") as f: - lines = (line.strip().split() for line in f if line[0] != "#") - for line in lines: - # Skip word segmentation errors. - if use_words and line[1] == "err": - continue - text = " ".join(line[8:]) - - # Remove garbage tokens: - text = text.replace("#", "") - - # Swap word sep form | to wordsep - text = re.sub(r"\|+|\s", wordsep, text).strip(wordsep) - form_key = "-".join(line[0].split("-")[:2]) - line_key = "-".join(line[0].split("-")[:3]) - box_idx = 4 - use_words - box = tuple(int(val) for val in line[box_idx : box_idx + 4]) - forms[form_key].append({"key": line_key, "box": box, "text": text}) - return forms - - -class Preprocessor: - """A preprocessor for the IAM dataset.""" - - # TODO: add lower case only to when generating... - - def __init__( - self, - data_dir: Union[str, Path], - num_features: int, - tokens_path: Optional[Union[str, Path]] = None, - lexicon_path: Optional[Union[str, Path]] = None, - use_words: bool = False, - prepend_wordsep: bool = False, - ) -> None: - self.wordsep = "▁" - self._use_word = use_words - self._prepend_wordsep = prepend_wordsep - - self.data_dir = Path(data_dir) - - self.forms = load_metadata(self.data_dir, self.wordsep, use_words=use_words) - - # Load the set of graphemes: - graphemes = set() - for _, form in self.forms.items(): - for line in form: - graphemes.update(line["text"].lower()) - self.graphemes = sorted(graphemes) - - # Build the token-to-index and index-to-token maps. - if tokens_path is not None: - with open(tokens_path, "r") as f: - self.tokens = [line.strip() for line in f] - else: - self.tokens = self.graphemes - - if lexicon_path is not None: - with open(lexicon_path, "r") as f: - lexicon = (line.strip().split() for line in f) - lexicon = {line[0]: line[1:] for line in lexicon} - self.lexicon = lexicon - else: - self.lexicon = None - - self.graphemes_to_index = {t: i for i, t in enumerate(self.graphemes)} - self.tokens_to_index = {t: i for i, t in enumerate(self.tokens)} - self.num_features = num_features - self.text = [] - - @property - def num_tokens(self) -> int: - """Returns the number or tokens.""" - return len(self.tokens) - - @property - def use_words(self) -> bool: - """If words are used.""" - return self._use_word - - def extract_train_text(self) -> None: - """Extracts training text.""" - keys = [] - with open(self.data_dir / "task" / "trainset.txt") as f: - keys.extend((line.strip() for line in f)) - - for _, examples in self.forms.items(): - for example in examples: - if example["key"] not in keys: - continue - self.text.append(example["text"].lower()) - - def to_index(self, line: str) -> torch.LongTensor: - """Converts text to a tensor of indices.""" - token_to_index = self.graphemes_to_index - if self.lexicon is not None: - if len(line) > 0: - # If the word is not found in the lexicon, fall back to letters. - line = [ - t - for w in line.split(self.wordsep) - for t in self.lexicon.get(w, self.wordsep + w) - ] - token_to_index = self.tokens_to_index - if self._prepend_wordsep: - line = itertools.chain([self.wordsep], line) - return torch.LongTensor([token_to_index[t] for t in line]) - - def to_text(self, indices: List[int]) -> str: - """Converts indices to text.""" - # Roughly the inverse of `to_index` - encoding = self.graphemes - if self.lexicon is not None: - encoding = self.tokens - return self._post_process(encoding[i] for i in indices) - - def tokens_to_text(self, indices: List[int]) -> str: - """Converts tokens to text.""" - return self._post_process(self.tokens[i] for i in indices) - - def _post_process(self, indices: List[int]) -> str: - """A list join.""" - return "".join(indices).strip(self.wordsep) - - -@click.command() -@click.option("--data_dir", type=str, default=None, help="Path to iam dataset") -@click.option( - "--use_words", is_flag=True, help="Load word segmented dataset instead of lines" -) -@click.option( - "--save_text", type=str, default=None, help="Path to save parsed train text" -) -@click.option("--save_tokens", type=str, default=None, help="Path to save tokens") -def cli( - data_dir: Optional[str], - use_words: bool, - save_text: Optional[str], - save_tokens: Optional[str], -) -> None: - """CLI for extracting text data from the iam dataset.""" - if data_dir is None: - data_dir = ( - Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb" - ) - logger.debug(f"Using data dir: {data_dir}") - if not data_dir.exists(): - raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") - else: - data_dir = Path(data_dir) - - preprocessor = Preprocessor(data_dir, 64, use_words=use_words) - preprocessor.extract_train_text() - - processed_dir = data_dir.parents[2] / "processed" / "iam_lines" - logger.debug(f"Saving processed files at: {processed_dir}") - - if save_text is not None: - logger.info("Saving training text") - with open(processed_dir / save_text, "w") as f: - f.write("\n".join(t for t in preprocessor.text)) - - if save_tokens is not None: - logger.info("Saving tokens") - with open(processed_dir / save_tokens, "w") as f: - f.write("\n".join(preprocessor.tokens)) - - -if __name__ == "__main__": - cli() diff --git a/src/text_recognizer/datasets/sentence_generator.py b/src/text_recognizer/datasets/sentence_generator.py deleted file mode 100644 index dd76652..0000000 --- a/src/text_recognizer/datasets/sentence_generator.py +++ /dev/null @@ -1,81 +0,0 @@ -"""Downloading the Brown corpus with NLTK for sentence generating.""" - -import itertools -import re -import string -from typing import Optional - -import nltk -from nltk.corpus.reader.util import ConcatenatedCorpusView -import numpy as np - -from text_recognizer.datasets.util import DATA_DIRNAME - -NLTK_DATA_DIRNAME = DATA_DIRNAME / "raw" / "nltk" - - -class SentenceGenerator: - """Generates text sentences using the Brown corpus.""" - - def __init__(self, max_length: Optional[int] = None) -> None: - """Loads the corpus and sets word start indices.""" - self.corpus = brown_corpus() - self.word_start_indices = [0] + [ - _.start(0) + 1 for _ in re.finditer(" ", self.corpus) - ] - self.max_length = max_length - - def generate(self, max_length: Optional[int] = None) -> str: - """Generates a word or sentences from the Brown corpus. - - Sample a string from the Brown corpus of length at least one word and at most max_length, padding to - max_length with the '_' characters if sentence is shorter. - - Args: - max_length (Optional[int]): The maximum number of characters in the sentence. Defaults to None. - - Returns: - str: A sentence from the Brown corpus. - - Raises: - ValueError: If max_length was not specified at initialization and not given as an argument. - - """ - if max_length is None: - max_length = self.max_length - if max_length is None: - raise ValueError( - "Must provide max_length to this method or when making this object." - ) - - index = np.random.randint(0, len(self.word_start_indices) - 1) - start_index = self.word_start_indices[index] - end_index_candidates = [] - for index in range(index + 1, len(self.word_start_indices)): - if self.word_start_indices[index] - start_index > max_length: - break - end_index_candidates.append(self.word_start_indices[index]) - end_index = np.random.choice(end_index_candidates) - sampled_text = self.corpus[start_index:end_index].strip() - padding = "_" * (max_length - len(sampled_text)) - return sampled_text + padding - - -def brown_corpus() -> str: - """Returns a single string with the Brown corpus with all punctuations stripped.""" - sentences = load_nltk_brown_corpus() - corpus = " ".join(itertools.chain.from_iterable(sentences)) - corpus = corpus.translate({ord(c): None for c in string.punctuation}) - corpus = re.sub(" +", " ", corpus) - return corpus - - -def load_nltk_brown_corpus() -> ConcatenatedCorpusView: - """Load the Brown corpus using the NLTK library.""" - nltk.data.path.append(NLTK_DATA_DIRNAME) - try: - nltk.corpus.brown.sents() - except LookupError: - NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) - nltk.download("brown", download_dir=NLTK_DATA_DIRNAME) - return nltk.corpus.brown.sents() diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py deleted file mode 100644 index b6a48f5..0000000 --- a/src/text_recognizer/datasets/transforms.py +++ /dev/null @@ -1,266 +0,0 @@ -"""Transforms for PyTorch datasets.""" -from abc import abstractmethod -from pathlib import Path -import random -from typing import Any, Optional, Union - -from loguru import logger -import numpy as np -from PIL import Image -import torch -from torch import Tensor -import torch.nn.functional as F -from torchvision import transforms -from torchvision.transforms import ( - ColorJitter, - Compose, - Normalize, - RandomAffine, - RandomHorizontalFlip, - RandomRotation, - ToPILImage, - ToTensor, -) - -from text_recognizer.datasets.iam_preprocessor import Preprocessor -from text_recognizer.datasets.util import EmnistMapper - - -class RandomResizeCrop: - """Image transform with random resize and crop applied. - - Stolen from - - https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py - - """ - - def __init__(self, jitter: int = 10, ratio: float = 0.5) -> None: - self.jitter = jitter - self.ratio = ratio - - def __call__(self, img: np.ndarray) -> np.ndarray: - """Applies random crop and rotation to an image.""" - w, h = img.size - - # pad with white: - img = transforms.functional.pad(img, self.jitter, fill=255) - - # crop at random (x, y): - x = self.jitter + random.randint(-self.jitter, self.jitter) - y = self.jitter + random.randint(-self.jitter, self.jitter) - - # randomize aspect ratio: - size_w = w * random.uniform(1 - self.ratio, 1 + self.ratio) - size = (h, int(size_w)) - img = transforms.functional.resized_crop(img, y, x, h, w, size) - return img - - -class Transpose: - """Transposes the EMNIST image to the correct orientation.""" - - def __call__(self, image: Image) -> np.ndarray: - """Swaps axis.""" - return np.array(image).swapaxes(0, 1) - - -class Resize: - """Resizes a tensor to a specified width.""" - - def __init__(self, width: int = 952) -> None: - # The default is 952 because of the IAM dataset. - self.width = width - - def __call__(self, image: Tensor) -> Tensor: - """Resize tensor in the last dimension.""" - return F.interpolate(image, size=self.width, mode="nearest") - - -class AddTokens: - """Adds start of sequence and end of sequence tokens to target tensor.""" - - def __init__(self, pad_token: str, eos_token: str, init_token: str = None) -> None: - self.init_token = init_token - self.pad_token = pad_token - self.eos_token = eos_token - if self.init_token is not None: - self.emnist_mapper = EmnistMapper( - init_token=self.init_token, - pad_token=self.pad_token, - eos_token=self.eos_token, - ) - else: - self.emnist_mapper = EmnistMapper( - pad_token=self.pad_token, eos_token=self.eos_token, - ) - self.pad_value = self.emnist_mapper(self.pad_token) - self.eos_value = self.emnist_mapper(self.eos_token) - - def __call__(self, target: Tensor) -> Tensor: - """Adds a sos token to the begining and a eos token to the end of a target sequence.""" - dtype, device = target.dtype, target.device - - # Find the where padding starts. - pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item() - - target[pad_index] = self.eos_value - - if self.init_token is not None: - self.sos_value = self.emnist_mapper(self.init_token) - sos = torch.tensor([self.sos_value], dtype=dtype, device=device) - target = torch.cat([sos, target], dim=0) - - return target - - -class ApplyContrast: - """Sets everything below a threshold to zero, i.e. increase contrast.""" - - def __init__(self, low: float = 0.0, high: float = 0.25) -> None: - self.low = low - self.high = high - - def __call__(self, x: Tensor) -> Tensor: - """Apply mask binary mask to input tensor.""" - mask = x > np.random.RandomState().uniform(low=self.low, high=self.high) - return x * mask - - -class Unsqueeze: - """Add a dimension to the tensor.""" - - def __call__(self, x: Tensor) -> Tensor: - """Adds dim.""" - return x.unsqueeze(0) - - -class Squeeze: - """Removes the first dimension of a tensor.""" - - def __call__(self, x: Tensor) -> Tensor: - """Removes first dim.""" - return x.squeeze(0) - - -class ToLower: - """Converts target to lower case.""" - - def __call__(self, target: Tensor) -> Tensor: - """Corrects index value in target tensor.""" - device = target.device - return torch.stack([x - 26 if x > 35 else x for x in target]).to(device) - - -class ToCharcters: - """Converts integers to characters.""" - - def __init__( - self, pad_token: str, eos_token: str, init_token: str = None, lower: bool = True - ) -> None: - self.init_token = init_token - self.pad_token = pad_token - self.eos_token = eos_token - if self.init_token is not None: - self.emnist_mapper = EmnistMapper( - init_token=self.init_token, - pad_token=self.pad_token, - eos_token=self.eos_token, - lower=lower, - ) - else: - self.emnist_mapper = EmnistMapper( - pad_token=self.pad_token, eos_token=self.eos_token, lower=lower - ) - - def __call__(self, y: Tensor) -> str: - """Converts a Tensor to a str.""" - return ( - "".join([self.emnist_mapper(int(i)) for i in y]) - .strip("_") - .replace(" ", "▁") - ) - - -class WordPieces: - """Abstract transform for word pieces.""" - - def __init__( - self, - num_features: int, - data_dir: Optional[Union[str, Path]] = None, - tokens: Optional[Union[str, Path]] = None, - lexicon: Optional[Union[str, Path]] = None, - use_words: bool = False, - prepend_wordsep: bool = False, - ) -> None: - if data_dir is None: - data_dir = ( - Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb" - ) - logger.debug(f"Using data dir: {data_dir}") - if not data_dir.exists(): - raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") - else: - data_dir = Path(data_dir) - processed_path = ( - Path(__file__).resolve().parents[3] / "data" / "processed" / "iam_lines" - ) - tokens_path = processed_path / tokens - lexicon_path = processed_path / lexicon - - self.preprocessor = Preprocessor( - data_dir, - num_features, - tokens_path, - lexicon_path, - use_words, - prepend_wordsep, - ) - - @abstractmethod - def __call__(self, *args, **kwargs) -> Any: - """Transforms input.""" - ... - - -class ToWordPieces(WordPieces): - """Transforms str to word pieces.""" - - def __init__( - self, - num_features: int, - data_dir: Optional[Union[str, Path]] = None, - tokens: Optional[Union[str, Path]] = None, - lexicon: Optional[Union[str, Path]] = None, - use_words: bool = False, - prepend_wordsep: bool = False, - ) -> None: - super().__init__( - num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep - ) - - def __call__(self, line: str) -> Tensor: - """Transforms str to word pieces.""" - return self.preprocessor.to_index(line) - - -class ToText(WordPieces): - """Takes word pieces and converts them to text.""" - - def __init__( - self, - num_features: int, - data_dir: Optional[Union[str, Path]] = None, - tokens: Optional[Union[str, Path]] = None, - lexicon: Optional[Union[str, Path]] = None, - use_words: bool = False, - prepend_wordsep: bool = False, - ) -> None: - super().__init__( - num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep - ) - - def __call__(self, x: Tensor) -> str: - """Converts tensor to text.""" - return self.preprocessor.to_text(x.tolist()) diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py deleted file mode 100644 index da87756..0000000 --- a/src/text_recognizer/datasets/util.py +++ /dev/null @@ -1,209 +0,0 @@ -"""Util functions for datasets.""" -import hashlib -import json -import os -from pathlib import Path -import string -from typing import Dict, List, Optional, Union -from urllib.request import urlretrieve - -from loguru import logger -import numpy as np -import torch -from torch import Tensor -from torchvision.datasets import EMNIST -from tqdm import tqdm - -DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data" -ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json" - - -def save_emnist_essentials(emnsit_dataset: EMNIST = EMNIST) -> None: - """Extract and saves EMNIST essentials.""" - labels = emnsit_dataset.classes - labels.sort() - mapping = [(i, str(label)) for i, label in enumerate(labels)] - essentials = { - "mapping": mapping, - "input_shape": tuple(np.array(emnsit_dataset[0][0]).shape[:]), - } - logger.info("Saving emnist essentials...") - with open(ESSENTIALS_FILENAME, "w") as f: - json.dump(essentials, f) - - -def download_emnist() -> None: - """Download the EMNIST dataset via the PyTorch class.""" - logger.info(f"Data directory is: {DATA_DIRNAME}") - dataset = EMNIST(root=DATA_DIRNAME, split="byclass", download=True) - save_emnist_essentials(dataset) - - -class EmnistMapper: - """Mapper between network output to Emnist character.""" - - def __init__( - self, - pad_token: str, - init_token: Optional[str] = None, - eos_token: Optional[str] = None, - lower: bool = False, - ) -> None: - """Loads the emnist essentials file with the mapping and input shape.""" - self.init_token = init_token - self.pad_token = pad_token - self.eos_token = eos_token - self.lower = lower - - self.essentials = self._load_emnist_essentials() - # Load dataset information. - self._mapping = dict(self.essentials["mapping"]) - self._augment_emnist_mapping() - self._inverse_mapping = {v: k for k, v in self.mapping.items()} - self._num_classes = len(self.mapping) - self._input_shape = self.essentials["input_shape"] - - def __call__(self, token: Union[str, int, np.uint8, Tensor]) -> Union[str, int]: - """Maps the token to emnist character or character index. - - If the token is an integer (index), the method will return the Emnist character corresponding to that index. - If the token is a str (Emnist character), the method will return the corresponding index for that character. - - Args: - token (Union[str, int, np.uint8, Tensor]): Either a string or index (integer). - - Returns: - Union[str, int]: The mapping result. - - Raises: - KeyError: If the index or string does not exist in the mapping. - - """ - if ( - (isinstance(token, np.uint8) or isinstance(token, int)) - or torch.is_tensor(token) - and int(token) in self.mapping - ): - return self.mapping[int(token)] - elif isinstance(token, str) and token in self._inverse_mapping: - return self._inverse_mapping[token] - else: - raise KeyError(f"Token {token} does not exist in the mappings.") - - @property - def mapping(self) -> Dict: - """Returns the mapping between index and character.""" - return self._mapping - - @property - def inverse_mapping(self) -> Dict: - """Returns the mapping between character and index.""" - return self._inverse_mapping - - @property - def num_classes(self) -> int: - """Returns the number of classes in the dataset.""" - return self._num_classes - - @property - def input_shape(self) -> List[int]: - """Returns the input shape of the Emnist characters.""" - return self._input_shape - - def _load_emnist_essentials(self) -> Dict: - """Load the EMNIST mapping.""" - with open(str(ESSENTIALS_FILENAME)) as f: - essentials = json.load(f) - return essentials - - def _augment_emnist_mapping(self) -> None: - """Augment the mapping with extra symbols.""" - # Extra symbols in IAM dataset - if self.lower: - self._mapping = { - k: str(v) - for k, v in enumerate(list(range(10)) + list(string.ascii_lowercase)) - } - - extra_symbols = [ - " ", - "!", - '"', - "#", - "&", - "'", - "(", - ")", - "*", - "+", - ",", - "-", - ".", - "/", - ":", - ";", - "?", - ] - - # padding symbol, and acts as blank symbol as well. - extra_symbols.append(self.pad_token) - - if self.init_token is not None: - extra_symbols.append(self.init_token) - - if self.eos_token is not None: - extra_symbols.append(self.eos_token) - - max_key = max(self.mapping.keys()) - extra_mapping = {} - for i, symbol in enumerate(extra_symbols): - extra_mapping[max_key + 1 + i] = symbol - - self._mapping = {**self.mapping, **extra_mapping} - - -def compute_sha256(filename: Union[Path, str]) -> str: - """Returns the SHA256 checksum of a file.""" - with open(filename, "rb") as f: - return hashlib.sha256(f.read()).hexdigest() - - -class TqdmUpTo(tqdm): - """TQDM progress bar when downloading files. - - From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py - - """ - - def update_to( - self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None - ) -> None: - """Updates the progress bar. - - Args: - blocks (int): Number of blocks transferred so far. Defaults to 1. - block_size (int): Size of each block, in tqdm units. Defaults to 1. - total_size (Optional[int]): Total size in tqdm units. Defaults to None. - """ - if total_size is not None: - self.total = total_size # pylint: disable=attribute-defined-outside-init - self.update(blocks * block_size - self.n) - - -def download_url(url: str, filename: str) -> None: - """Downloads a file from url to filename, with a progress bar.""" - with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: - urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec - - -def _download_raw_dataset(metadata: Dict) -> None: - if os.path.exists(metadata["filename"]): - return - logger.info(f"Downloading raw dataset from {metadata['url']}...") - download_url(metadata["url"], metadata["filename"]) - logger.info("Computing SHA-256...") - sha256 = compute_sha256(metadata["filename"]) - if sha256 != metadata["sha256"]: - raise ValueError( - "Downloaded data file SHA-256 does not match that listed in metadata document." - ) |