summaryrefslogtreecommitdiff
path: root/text_recognizer/datasets
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
commit7e8e54e84c63171e748bbf09516fd517e6821ace (patch)
tree996093f75a5d488dddf7ea1f159ed343a561ef89 /text_recognizer/datasets
parentb0719d84138b6bbe5f04a4982dfca673aea1a368 (diff)
Inital commit for refactoring to lightning
Diffstat (limited to 'text_recognizer/datasets')
-rw-r--r--text_recognizer/datasets/__init__.py39
-rw-r--r--text_recognizer/datasets/dataset.py152
-rw-r--r--text_recognizer/datasets/emnist_dataset.py131
-rw-r--r--text_recognizer/datasets/emnist_essentials.json1
-rw-r--r--text_recognizer/datasets/emnist_lines_dataset.py359
-rw-r--r--text_recognizer/datasets/iam_dataset.py133
-rw-r--r--text_recognizer/datasets/iam_lines_dataset.py110
-rw-r--r--text_recognizer/datasets/iam_paragraphs_dataset.py291
-rw-r--r--text_recognizer/datasets/iam_preprocessor.py196
-rw-r--r--text_recognizer/datasets/sentence_generator.py81
-rw-r--r--text_recognizer/datasets/transforms.py266
-rw-r--r--text_recognizer/datasets/util.py209
12 files changed, 1968 insertions, 0 deletions
diff --git a/text_recognizer/datasets/__init__.py b/text_recognizer/datasets/__init__.py
new file mode 100644
index 0000000..a6c1c59
--- /dev/null
+++ b/text_recognizer/datasets/__init__.py
@@ -0,0 +1,39 @@
+"""Dataset modules."""
+from .emnist_dataset import EmnistDataset
+from .emnist_lines_dataset import (
+ construct_image_from_string,
+ EmnistLinesDataset,
+ get_samples_by_character,
+)
+from .iam_dataset import IamDataset
+from .iam_lines_dataset import IamLinesDataset
+from .iam_paragraphs_dataset import IamParagraphsDataset
+from .iam_preprocessor import load_metadata, Preprocessor
+from .transforms import AddTokens, Transpose
+from .util import (
+ _download_raw_dataset,
+ compute_sha256,
+ DATA_DIRNAME,
+ download_url,
+ EmnistMapper,
+ ESSENTIALS_FILENAME,
+)
+
+__all__ = [
+ "_download_raw_dataset",
+ "AddTokens",
+ "compute_sha256",
+ "construct_image_from_string",
+ "DATA_DIRNAME",
+ "download_url",
+ "EmnistDataset",
+ "EmnistMapper",
+ "EmnistLinesDataset",
+ "get_samples_by_character",
+ "load_metadata",
+ "IamDataset",
+ "IamLinesDataset",
+ "IamParagraphsDataset",
+ "Preprocessor",
+ "Transpose",
+]
diff --git a/text_recognizer/datasets/dataset.py b/text_recognizer/datasets/dataset.py
new file mode 100644
index 0000000..e794605
--- /dev/null
+++ b/text_recognizer/datasets/dataset.py
@@ -0,0 +1,152 @@
+"""Abstract dataset class."""
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import Tensor
+from torch.utils import data
+from torchvision.transforms import ToTensor
+
+import text_recognizer.datasets.transforms as transforms
+from text_recognizer.datasets.util import EmnistMapper
+
+
+class Dataset(data.Dataset):
+ """Abstract class for with common methods for all datasets."""
+
+ def __init__(
+ self,
+ train: bool,
+ subsample_fraction: float = None,
+ transform: Optional[List[Dict]] = None,
+ target_transform: Optional[List[Dict]] = None,
+ init_token: Optional[str] = None,
+ pad_token: Optional[str] = None,
+ eos_token: Optional[str] = None,
+ lower: bool = False,
+ ) -> None:
+ """Initialization of Dataset class.
+
+ Args:
+ train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False.
+ subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None.
+ transform (Optional[List[Dict]]): List of Transform types and args for input data. Defaults to None.
+ target_transform (Optional[List[Dict]]): List of Transform types and args for output data. Defaults to None.
+ init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
+ pad_token (Optional[str]): String representing the pad token. Defaults to None.
+ eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
+ lower (bool): Only use lower case letters. Defaults to False.
+
+ Raises:
+ ValueError: If subsample_fraction is not None and outside the range (0, 1).
+
+ """
+ self.train = train
+ self.split = "train" if self.train else "test"
+
+ if subsample_fraction is not None:
+ if not 0.0 < subsample_fraction < 1.0:
+ raise ValueError("The subsample fraction must be in (0, 1).")
+ self.subsample_fraction = subsample_fraction
+
+ self._mapper = EmnistMapper(
+ init_token=init_token, eos_token=eos_token, pad_token=pad_token, lower=lower
+ )
+ self._input_shape = self._mapper.input_shape
+ self._output_shape = self._mapper._num_classes
+ self.num_classes = self.mapper.num_classes
+
+ # Set transforms.
+ self.transform = self._configure_transform(transform)
+ self.target_transform = self._configure_target_transform(target_transform)
+
+ self._data = None
+ self._targets = None
+
+ def _configure_transform(self, transform: List[Dict]) -> transforms.Compose:
+ transform_list = []
+ if transform is not None:
+ for t in transform:
+ t_type = t["type"]
+ t_args = t["args"] or {}
+ transform_list.append(getattr(transforms, t_type)(**t_args))
+ else:
+ transform_list.append(ToTensor())
+ return transforms.Compose(transform_list)
+
+ def _configure_target_transform(
+ self, target_transform: List[Dict]
+ ) -> transforms.Compose:
+ target_transform_list = [torch.tensor]
+ if target_transform is not None:
+ for t in target_transform:
+ t_type = t["type"]
+ t_args = t["args"] or {}
+ target_transform_list.append(getattr(transforms, t_type)(**t_args))
+ return transforms.Compose(target_transform_list)
+
+ @property
+ def data(self) -> Tensor:
+ """The input data."""
+ return self._data
+
+ @property
+ def targets(self) -> Tensor:
+ """The target data."""
+ return self._targets
+
+ @property
+ def input_shape(self) -> Tuple:
+ """Input shape of the data."""
+ return self._input_shape
+
+ @property
+ def output_shape(self) -> Tuple:
+ """Output shape of the data."""
+ return self._output_shape
+
+ @property
+ def mapper(self) -> EmnistMapper:
+ """Returns the EmnistMapper."""
+ return self._mapper
+
+ @property
+ def mapping(self) -> Dict:
+ """Return EMNIST mapping from index to character."""
+ return self._mapper.mapping
+
+ @property
+ def inverse_mapping(self) -> Dict:
+ """Returns the inverse mapping from character to index."""
+ return self.mapper.inverse_mapping
+
+ def _subsample(self) -> None:
+ """Only this fraction of the data will be loaded."""
+ if self.subsample_fraction is None:
+ return
+ num_subsample = int(self.data.shape[0] * self.subsample_fraction)
+ self._data = self.data[:num_subsample]
+ self._targets = self.targets[:num_subsample]
+
+ def __len__(self) -> int:
+ """Returns the length of the dataset."""
+ return len(self.data)
+
+ def load_or_generate_data(self) -> None:
+ """Load or generate dataset data."""
+ raise NotImplementedError
+
+ def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
+ """Fetches samples from the dataset.
+
+ Args:
+ index (Union[int, torch.Tensor]): The indices of the samples to fetch.
+
+ Raises:
+ NotImplementedError: If the method is not implemented in child class.
+
+ """
+ raise NotImplementedError
+
+ def __repr__(self) -> str:
+ """Returns information about the dataset."""
+ raise NotImplementedError
diff --git a/text_recognizer/datasets/emnist_dataset.py b/text_recognizer/datasets/emnist_dataset.py
new file mode 100644
index 0000000..9884fdf
--- /dev/null
+++ b/text_recognizer/datasets/emnist_dataset.py
@@ -0,0 +1,131 @@
+"""Emnist dataset: black and white images of handwritten characters (Aa-Zz) and digits (0-9)."""
+
+import json
+from pathlib import Path
+from typing import Callable, Optional, Tuple, Union
+
+from loguru import logger
+import numpy as np
+from PIL import Image
+import torch
+from torch import Tensor
+from torchvision.datasets import EMNIST
+from torchvision.transforms import Compose, ToTensor
+
+from text_recognizer.datasets.dataset import Dataset
+from text_recognizer.datasets.transforms import Transpose
+from text_recognizer.datasets.util import DATA_DIRNAME
+
+
+class EmnistDataset(Dataset):
+ """This is a class for resampling and subsampling the PyTorch EMNIST dataset."""
+
+ def __init__(
+ self,
+ pad_token: str = None,
+ train: bool = False,
+ sample_to_balance: bool = False,
+ subsample_fraction: float = None,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ seed: int = 4711,
+ ) -> None:
+ """Loads the dataset and the mappings.
+
+ Args:
+ pad_token (str): The pad token symbol. Defaults to _.
+ train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False.
+ sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False.
+ subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None.
+ transform (Optional[Callable]): Transform(s) for input data. Defaults to None.
+ target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None.
+ seed (int): Seed number. Defaults to 4711.
+
+ """
+ super().__init__(
+ train=train,
+ subsample_fraction=subsample_fraction,
+ transform=transform,
+ target_transform=target_transform,
+ pad_token=pad_token,
+ )
+
+ self.sample_to_balance = sample_to_balance
+
+ # Have to transpose the emnist characters, ToTensor norms input between [0,1].
+ if transform is None:
+ self.transform = Compose([Transpose(), ToTensor()])
+
+ self.target_transform = None
+
+ self.seed = seed
+
+ def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
+ """Fetches samples from the dataset.
+
+ Args:
+ index (Union[int, Tensor]): The indices of the samples to fetch.
+
+ Returns:
+ Tuple[Tensor, Tensor]: Data target tuple.
+
+ """
+ if torch.is_tensor(index):
+ index = index.tolist()
+
+ data = self.data[index]
+ targets = self.targets[index]
+
+ if self.transform:
+ data = self.transform(data)
+
+ if self.target_transform:
+ targets = self.target_transform(targets)
+
+ return data, targets
+
+ def __repr__(self) -> str:
+ """Returns information about the dataset."""
+ return (
+ "EMNIST Dataset\n"
+ f"Num classes: {self.num_classes}\n"
+ f"Input shape: {self.input_shape}\n"
+ f"Mapping: {self.mapper.mapping}\n"
+ )
+
+ def _sample_to_balance(self) -> None:
+ """Because the dataset is not balanced, we take at most the mean number of instances per class."""
+ np.random.seed(self.seed)
+ x = self._data
+ y = self._targets
+ num_to_sample = int(np.bincount(y.flatten()).mean())
+ all_sampled_indices = []
+ for label in np.unique(y.flatten()):
+ inds = np.where(y == label)[0]
+ sampled_indices = np.unique(np.random.choice(inds, num_to_sample))
+ all_sampled_indices.append(sampled_indices)
+ indices = np.concatenate(all_sampled_indices)
+ x_sampled = x[indices]
+ y_sampled = y[indices]
+ self._data = x_sampled
+ self._targets = y_sampled
+
+ def load_or_generate_data(self) -> None:
+ """Fetch the EMNIST dataset."""
+ dataset = EMNIST(
+ root=DATA_DIRNAME,
+ split="byclass",
+ train=self.train,
+ download=False,
+ transform=None,
+ target_transform=None,
+ )
+
+ self._data = dataset.data
+ self._targets = dataset.targets
+
+ if self.sample_to_balance:
+ self._sample_to_balance()
+
+ if self.subsample_fraction is not None:
+ self._subsample()
diff --git a/text_recognizer/datasets/emnist_essentials.json b/text_recognizer/datasets/emnist_essentials.json
new file mode 100644
index 0000000..2a0648a
--- /dev/null
+++ b/text_recognizer/datasets/emnist_essentials.json
@@ -0,0 +1 @@
+{"mapping": [[0, "0"], [1, "1"], [2, "2"], [3, "3"], [4, "4"], [5, "5"], [6, "6"], [7, "7"], [8, "8"], [9, "9"], [10, "A"], [11, "B"], [12, "C"], [13, "D"], [14, "E"], [15, "F"], [16, "G"], [17, "H"], [18, "I"], [19, "J"], [20, "K"], [21, "L"], [22, "M"], [23, "N"], [24, "O"], [25, "P"], [26, "Q"], [27, "R"], [28, "S"], [29, "T"], [30, "U"], [31, "V"], [32, "W"], [33, "X"], [34, "Y"], [35, "Z"], [36, "a"], [37, "b"], [38, "c"], [39, "d"], [40, "e"], [41, "f"], [42, "g"], [43, "h"], [44, "i"], [45, "j"], [46, "k"], [47, "l"], [48, "m"], [49, "n"], [50, "o"], [51, "p"], [52, "q"], [53, "r"], [54, "s"], [55, "t"], [56, "u"], [57, "v"], [58, "w"], [59, "x"], [60, "y"], [61, "z"]], "input_shape": [28, 28]}
diff --git a/text_recognizer/datasets/emnist_lines_dataset.py b/text_recognizer/datasets/emnist_lines_dataset.py
new file mode 100644
index 0000000..1992446
--- /dev/null
+++ b/text_recognizer/datasets/emnist_lines_dataset.py
@@ -0,0 +1,359 @@
+"""Emnist Lines dataset: synthetic handwritten lines dataset made from Emnist characters."""
+
+from collections import defaultdict
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import click
+import h5py
+from loguru import logger
+import numpy as np
+import torch
+from torch import Tensor
+import torch.nn.functional as F
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets.dataset import Dataset
+from text_recognizer.datasets.emnist_dataset import EmnistDataset, Transpose
+from text_recognizer.datasets.sentence_generator import SentenceGenerator
+from text_recognizer.datasets.util import (
+ DATA_DIRNAME,
+ EmnistMapper,
+ ESSENTIALS_FILENAME,
+)
+
+DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines"
+
+MAX_WIDTH = 952
+
+
+class EmnistLinesDataset(Dataset):
+ """Synthetic dataset of lines from the Brown corpus with Emnist characters."""
+
+ def __init__(
+ self,
+ train: bool = False,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ subsample_fraction: float = None,
+ max_length: int = 34,
+ min_overlap: float = 0,
+ max_overlap: float = 0.33,
+ num_samples: int = 10000,
+ seed: int = 4711,
+ init_token: Optional[str] = None,
+ pad_token: Optional[str] = None,
+ eos_token: Optional[str] = None,
+ lower: bool = False,
+ ) -> None:
+ """Set attributes and loads the dataset.
+
+ Args:
+ train (bool): Flag for the filename. Defaults to False. Defaults to None.
+ transform (Optional[Callable]): The transform of the data. Defaults to None.
+ target_transform (Optional[Callable]): The transform of the target. Defaults to None.
+ subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None.
+ max_length (int): The maximum number of characters. Defaults to 34.
+ min_overlap (float): The minimum overlap between concatenated images. Defaults to 0.
+ max_overlap (float): The maximum overlap between concatenated images. Defaults to 0.33.
+ num_samples (int): Number of samples to generate. Defaults to 10000.
+ seed (int): Seed number. Defaults to 4711.
+ init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
+ pad_token (Optional[str]): String representing the pad token. Defaults to None.
+ eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
+ lower (bool): If True, convert uppercase letters to lowercase. Otherwise, use both upper and lowercase.
+
+ """
+ self.pad_token = "_" if pad_token is None else pad_token
+
+ super().__init__(
+ train=train,
+ transform=transform,
+ target_transform=target_transform,
+ subsample_fraction=subsample_fraction,
+ init_token=init_token,
+ pad_token=self.pad_token,
+ eos_token=eos_token,
+ lower=lower,
+ )
+
+ # Extract dataset information.
+ self._input_shape = self._mapper.input_shape
+ self.num_classes = self._mapper.num_classes
+
+ self.max_length = max_length
+ self.min_overlap = min_overlap
+ self.max_overlap = max_overlap
+ self.num_samples = num_samples
+ self._input_shape = (
+ self.input_shape[0],
+ self.input_shape[1] * self.max_length,
+ )
+ self._output_shape = (self.max_length, self.num_classes)
+ self.seed = seed
+
+ # Placeholders for the dataset.
+ self._data = None
+ self._target = None
+
+ def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
+ """Fetches data, target pair of the dataset for a given and index or indices.
+
+ Args:
+ index (Union[int, Tensor]): Either a list or int of indices/index.
+
+ Returns:
+ Tuple[Tensor, Tensor]: Data target pair.
+
+ """
+ if torch.is_tensor(index):
+ index = index.tolist()
+
+ data = self.data[index]
+ targets = self.targets[index]
+
+ if self.transform:
+ data = self.transform(data)
+
+ if self.target_transform:
+ targets = self.target_transform(targets)
+
+ return data, targets
+
+ def __repr__(self) -> str:
+ """Returns information about the dataset."""
+ return (
+ "EMNIST Lines Dataset\n" # pylint: disable=no-member
+ f"Max length: {self.max_length}\n"
+ f"Min overlap: {self.min_overlap}\n"
+ f"Max overlap: {self.max_overlap}\n"
+ f"Num classes: {self.num_classes}\n"
+ f"Input shape: {self.input_shape}\n"
+ f"Data: {self.data.shape}\n"
+ f"Tagets: {self.targets.shape}\n"
+ )
+
+ @property
+ def data_filename(self) -> Path:
+ """Path to the h5 file."""
+ filename = "train.pt" if self.train else "test.pt"
+ return DATA_DIRNAME / filename
+
+ def load_or_generate_data(self) -> None:
+ """Loads the dataset, if it does not exist a new dataset is generated before loading it."""
+ np.random.seed(self.seed)
+
+ if not self.data_filename.exists():
+ self._generate_data()
+ self._load_data()
+ self._subsample()
+
+ def _load_data(self) -> None:
+ """Loads the dataset from the h5 file."""
+ logger.debug("EmnistLinesDataset loading data from HDF5...")
+ with h5py.File(self.data_filename, "r") as f:
+ self._data = f["data"][()]
+ self._targets = f["targets"][()]
+
+ def _generate_data(self) -> str:
+ """Generates a dataset with the Brown corpus and Emnist characters."""
+ logger.debug("Generating data...")
+
+ sentence_generator = SentenceGenerator(self.max_length)
+
+ # Load emnist dataset.
+ emnist = EmnistDataset(
+ train=self.train, sample_to_balance=True, pad_token=self.pad_token
+ )
+ emnist.load_or_generate_data()
+
+ samples_by_character = get_samples_by_character(
+ emnist.data.numpy(), emnist.targets.numpy(), self.mapper.mapping,
+ )
+
+ DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+ with h5py.File(self.data_filename, "a") as f:
+ data, targets = create_dataset_of_images(
+ self.num_samples,
+ samples_by_character,
+ sentence_generator,
+ self.min_overlap,
+ self.max_overlap,
+ )
+
+ targets = convert_strings_to_categorical_labels(
+ targets, emnist.inverse_mapping
+ )
+
+ f.create_dataset("data", data=data, dtype="u1", compression="lzf")
+ f.create_dataset("targets", data=targets, dtype="u1", compression="lzf")
+
+
+def get_samples_by_character(
+ samples: np.ndarray, labels: np.ndarray, mapping: Dict
+) -> defaultdict:
+ """Creates a dictionary with character as key and value as the list of images of that character.
+
+ Args:
+ samples (np.ndarray): Dataset of images of characters.
+ labels (np.ndarray): The labels for each image.
+ mapping (Dict): The Emnist mapping dictionary.
+
+ Returns:
+ defaultdict: A dictionary with characters as keys and list of images as values.
+
+ """
+ samples_by_character = defaultdict(list)
+ for sample, label in zip(samples, labels.flatten()):
+ samples_by_character[mapping[label]].append(sample)
+ return samples_by_character
+
+
+def select_letter_samples_for_string(
+ string: str, samples_by_character: Dict
+) -> List[np.ndarray]:
+ """Randomly selects Emnist characters to use for the senetence.
+
+ Args:
+ string (str): The word or sentence.
+ samples_by_character (Dict): The dictionary of emnist images of each character.
+
+ Returns:
+ List[np.ndarray]: A list of emnist images of the string.
+
+ """
+ zero_image = np.zeros((28, 28), np.uint8)
+ sample_image_by_character = {}
+ for character in string:
+ if character in sample_image_by_character:
+ continue
+ samples = samples_by_character[character]
+ sample = samples[np.random.choice(len(samples))] if samples else zero_image
+ sample_image_by_character[character] = sample.reshape(28, 28).swapaxes(0, 1)
+ return [sample_image_by_character[character] for character in string]
+
+
+def construct_image_from_string(
+ string: str, samples_by_character: Dict, min_overlap: float, max_overlap: float
+) -> np.ndarray:
+ """Concatenates images of the characters in the string.
+
+ The concatination is made with randomly selected overlap so that some portion of the character will overlap.
+
+ Args:
+ string (str): The word or sentence.
+ samples_by_character (Dict): The dictionary of emnist images of each character.
+ min_overlap (float): Minimum amount of overlap between Emnist images.
+ max_overlap (float): Maximum amount of overlap between Emnist images.
+
+ Returns:
+ np.ndarray: The Emnist image of the string.
+
+ """
+ overlap = np.random.uniform(min_overlap, max_overlap)
+ sampled_images = select_letter_samples_for_string(string, samples_by_character)
+ length = len(sampled_images)
+ height, width = sampled_images[0].shape
+ next_overlap_width = width - int(overlap * width)
+ concatenated_image = np.zeros((height, width * length), np.uint8)
+ x = 0
+ for image in sampled_images:
+ concatenated_image[:, x : (x + width)] += image
+ x += next_overlap_width
+
+ if concatenated_image.shape[-1] > MAX_WIDTH:
+ concatenated_image = Tensor(concatenated_image).unsqueeze(0)
+ concatenated_image = F.interpolate(
+ concatenated_image, size=MAX_WIDTH, mode="nearest"
+ )
+ concatenated_image = concatenated_image.squeeze(0).numpy()
+
+ return np.minimum(255, concatenated_image)
+
+
+def create_dataset_of_images(
+ length: int,
+ samples_by_character: Dict,
+ sentence_generator: SentenceGenerator,
+ min_overlap: float,
+ max_overlap: float,
+) -> Tuple[np.ndarray, List[str]]:
+ """Creates a dataset with images and labels from strings generated from the SentenceGenerator.
+
+ Args:
+ length (int): The number of characters for each string.
+ samples_by_character (Dict): The dictionary of emnist images of each character.
+ sentence_generator (SentenceGenerator): A SentenceGenerator objest.
+ min_overlap (float): Minimum amount of overlap between Emnist images.
+ max_overlap (float): Maximum amount of overlap between Emnist images.
+
+ Returns:
+ Tuple[np.ndarray, List[str]]: A list of Emnist images and a list of the strings (labels).
+
+ Raises:
+ RuntimeError: If the sentence generator is not able to generate a string.
+
+ """
+ sample_label = sentence_generator.generate()
+ sample_image = construct_image_from_string(sample_label, samples_by_character, 0, 0)
+ images = np.zeros((length, sample_image.shape[0], sample_image.shape[1]), np.uint8)
+ labels = []
+ for n in range(length):
+ label = None
+ # Try several times to generate before actually throwing an error.
+ for _ in range(10):
+ try:
+ label = sentence_generator.generate()
+ break
+ except Exception: # pylint: disable=broad-except
+ pass
+ if label is None:
+ raise RuntimeError("Was not able to generate a valid string.")
+ images[n] = construct_image_from_string(
+ label, samples_by_character, min_overlap, max_overlap
+ )
+ labels.append(label)
+ return images, labels
+
+
+def convert_strings_to_categorical_labels(
+ labels: List[str], mapping: Dict
+) -> np.ndarray:
+ """Translates a string of characters in to a target array of class int."""
+ return np.array([[mapping[c] for c in label] for label in labels])
+
+
+@click.command()
+@click.option(
+ "--max_length", type=int, default=34, help="Number of characters in a sentence."
+)
+@click.option(
+ "--min_overlap", type=float, default=0.0, help="Min overlap between characters."
+)
+@click.option(
+ "--max_overlap", type=float, default=0.33, help="Max overlap between characters."
+)
+@click.option("--num_train", type=int, default=10_000, help="Number of train examples.")
+@click.option("--num_test", type=int, default=1_000, help="Number of test examples.")
+def create_datasets(
+ max_length: int = 34,
+ min_overlap: float = 0,
+ max_overlap: float = 0.33,
+ num_train: int = 10000,
+ num_test: int = 1000,
+) -> None:
+ """Creates a training an validation dataset of Emnist lines."""
+ num_samples = [num_train, num_test]
+ for num, train in zip(num_samples, [True, False]):
+ emnist_lines = EmnistLinesDataset(
+ train=train,
+ max_length=max_length,
+ min_overlap=min_overlap,
+ max_overlap=max_overlap,
+ num_samples=num,
+ )
+ emnist_lines.load_or_generate_data()
+
+
+if __name__ == "__main__":
+ create_datasets()
diff --git a/text_recognizer/datasets/iam_dataset.py b/text_recognizer/datasets/iam_dataset.py
new file mode 100644
index 0000000..a8998b9
--- /dev/null
+++ b/text_recognizer/datasets/iam_dataset.py
@@ -0,0 +1,133 @@
+"""Class for loading the IAM dataset, which encompasses both paragraphs and lines, with associated utilities."""
+import os
+from typing import Any, Dict, List
+import zipfile
+
+from boltons.cacheutils import cachedproperty
+import defusedxml.ElementTree as ET
+from loguru import logger
+import toml
+
+from text_recognizer.datasets.util import _download_raw_dataset, DATA_DIRNAME
+
+RAW_DATA_DIRNAME = DATA_DIRNAME / "raw" / "iam"
+METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
+EXTRACTED_DATASET_DIRNAME = RAW_DATA_DIRNAME / "iamdb"
+RAW_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+
+DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be.
+LINE_REGION_PADDING = 0 # Add this many pixels around the exact coordinates.
+
+
+class IamDataset:
+ """IAM dataset.
+
+ "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text,
+ which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels."
+ From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database
+
+ The data split we will use is
+ IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text lines.
+ The validation set has been merged into the train set.
+ The train set has 7,101 lines from 326 writers.
+ The test set has 1,861 lines from 128 writers.
+ The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only.
+
+ """
+
+ def __init__(self) -> None:
+ self.metadata = toml.load(METADATA_FILENAME)
+
+ def load_or_generate_data(self) -> None:
+ """Downloads IAM dataset if xml files does not exist."""
+ if not self.xml_filenames:
+ self._download_iam()
+
+ @property
+ def xml_filenames(self) -> List:
+ """List of xml filenames."""
+ return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml"))
+
+ @property
+ def form_filenames(self) -> List:
+ """List of forms filenames."""
+ return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg"))
+
+ def _download_iam(self) -> None:
+ curdir = os.getcwd()
+ os.chdir(RAW_DATA_DIRNAME)
+ _download_raw_dataset(self.metadata)
+ _extract_raw_dataset(self.metadata)
+ os.chdir(curdir)
+
+ @property
+ def form_filenames_by_id(self) -> Dict:
+ """Creates a dictionary with filenames as keys and forms as values."""
+ return {filename.stem: filename for filename in self.form_filenames}
+
+ @cachedproperty
+ def line_strings_by_id(self) -> Dict:
+ """Return a dict from name of IAM form to a list of line texts in it."""
+ return {
+ filename.stem: _get_line_strings_from_xml_file(filename)
+ for filename in self.xml_filenames
+ }
+
+ @cachedproperty
+ def line_regions_by_id(self) -> Dict:
+ """Return a dict from name of IAM form to a list of (x1, x2, y1, y2) coordinates of all lines in it."""
+ return {
+ filename.stem: _get_line_regions_from_xml_file(filename)
+ for filename in self.xml_filenames
+ }
+
+ def __repr__(self) -> str:
+ """Print info about dataset."""
+ return "IAM Dataset\n" f"Number of forms: {len(self.xml_filenames)}\n"
+
+
+def _extract_raw_dataset(metadata: Dict) -> None:
+ logger.info("Extracting IAM data.")
+ with zipfile.ZipFile(metadata["filename"], "r") as zip_file:
+ zip_file.extractall()
+
+
+def _get_line_strings_from_xml_file(filename: str) -> List[str]:
+ """Get the text content of each line. Note that we replace &quot; with "."""
+ xml_root_element = ET.parse(filename).getroot() # nosec
+ xml_line_elements = xml_root_element.findall("handwritten-part/line")
+ return [el.attrib["text"].replace("&quot;", '"') for el in xml_line_elements]
+
+
+def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]:
+ """Get the line region dict for each line."""
+ xml_root_element = ET.parse(filename).getroot() # nosec
+ xml_line_elements = xml_root_element.findall("handwritten-part/line")
+ return [_get_line_region_from_xml_element(el) for el in xml_line_elements]
+
+
+def _get_line_region_from_xml_element(xml_line: Any) -> Dict[str, int]:
+ """Extracts coordinates for each line of text."""
+ # TODO: fix input!
+ word_elements = xml_line.findall("word/cmp")
+ x1s = [int(el.attrib["x"]) for el in word_elements]
+ y1s = [int(el.attrib["y"]) for el in word_elements]
+ x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements]
+ y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements]
+ return {
+ "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
+ "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
+ "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
+ "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
+ }
+
+
+def main() -> None:
+ """Initializes the dataset and print info about the dataset."""
+ dataset = IamDataset()
+ dataset.load_or_generate_data()
+ print(dataset)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/text_recognizer/datasets/iam_lines_dataset.py b/text_recognizer/datasets/iam_lines_dataset.py
new file mode 100644
index 0000000..1cb84bd
--- /dev/null
+++ b/text_recognizer/datasets/iam_lines_dataset.py
@@ -0,0 +1,110 @@
+"""IamLinesDataset class."""
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import h5py
+from loguru import logger
+import torch
+from torch import Tensor
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets.dataset import Dataset
+from text_recognizer.datasets.util import (
+ compute_sha256,
+ DATA_DIRNAME,
+ download_url,
+ EmnistMapper,
+)
+
+
+PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines"
+PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "iam_lines.h5"
+PROCESSED_DATA_URL = (
+ "https://s3-us-west-2.amazonaws.com/fsdl-public-assets/iam_lines.h5"
+)
+
+
+class IamLinesDataset(Dataset):
+ """IAM lines datasets for handwritten text lines."""
+
+ def __init__(
+ self,
+ train: bool = False,
+ subsample_fraction: float = None,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ init_token: Optional[str] = None,
+ pad_token: Optional[str] = None,
+ eos_token: Optional[str] = None,
+ lower: bool = False,
+ ) -> None:
+ self.pad_token = "_" if pad_token is None else pad_token
+
+ super().__init__(
+ train=train,
+ subsample_fraction=subsample_fraction,
+ transform=transform,
+ target_transform=target_transform,
+ init_token=init_token,
+ pad_token=pad_token,
+ eos_token=eos_token,
+ lower=lower,
+ )
+
+ @property
+ def input_shape(self) -> Tuple:
+ """Input shape of the data."""
+ return self.data.shape[1:] if self.data is not None else None
+
+ @property
+ def output_shape(self) -> Tuple:
+ """Output shape of the data."""
+ return (
+ self.targets.shape[1:] + (self.num_classes,)
+ if self.targets is not None
+ else None
+ )
+
+ def load_or_generate_data(self) -> None:
+ """Load or generate dataset data."""
+ if not PROCESSED_DATA_FILENAME.exists():
+ PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+ logger.info("Downloading IAM lines...")
+ download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME)
+ with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
+ self._data = f[f"x_{self.split}"][:]
+ self._targets = f[f"y_{self.split}"][:]
+ self._subsample()
+
+ def __repr__(self) -> str:
+ """Print info about the dataset."""
+ return (
+ "IAM Lines Dataset\n" # pylint: disable=no-member
+ f"Number classes: {self.num_classes}\n"
+ f"Mapping: {self.mapper.mapping}\n"
+ f"Data: {self.data.shape}\n"
+ f"Targets: {self.targets.shape}\n"
+ )
+
+ def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]:
+ """Fetches data, target pair of the dataset for a given and index or indices.
+
+ Args:
+ index (Union[int, Tensor]): Either a list or int of indices/index.
+
+ Returns:
+ Tuple[Tensor, Tensor]: Data target pair.
+
+ """
+ if torch.is_tensor(index):
+ index = index.tolist()
+
+ data = self.data[index]
+ targets = self.targets[index]
+
+ if self.transform:
+ data = self.transform(data)
+
+ if self.target_transform:
+ targets = self.target_transform(targets)
+
+ return data, targets
diff --git a/text_recognizer/datasets/iam_paragraphs_dataset.py b/text_recognizer/datasets/iam_paragraphs_dataset.py
new file mode 100644
index 0000000..8ba5142
--- /dev/null
+++ b/text_recognizer/datasets/iam_paragraphs_dataset.py
@@ -0,0 +1,291 @@
+"""IamParagraphsDataset class and functions for data processing."""
+import random
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import click
+import cv2
+import h5py
+from loguru import logger
+import numpy as np
+import torch
+from torch import Tensor
+from torchvision.transforms import ToTensor
+
+from text_recognizer import util
+from text_recognizer.datasets.dataset import Dataset
+from text_recognizer.datasets.iam_dataset import IamDataset
+from text_recognizer.datasets.util import (
+ compute_sha256,
+ DATA_DIRNAME,
+ download_url,
+ EmnistMapper,
+)
+
+INTERIM_DATA_DIRNAME = DATA_DIRNAME / "interim" / "iam_paragraphs"
+DEBUG_CROPS_DIRNAME = INTERIM_DATA_DIRNAME / "debug_crops"
+PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_paragraphs"
+CROPS_DIRNAME = PROCESSED_DATA_DIRNAME / "crops"
+GT_DIRNAME = PROCESSED_DATA_DIRNAME / "gt"
+
+PARAGRAPH_BUFFER = 50 # Pixels in the IAM form images to leave around the lines.
+TEST_FRACTION = 0.2
+SEED = 4711
+
+
+class IamParagraphsDataset(Dataset):
+ """IAM Paragraphs dataset for paragraphs of handwritten text."""
+
+ def __init__(
+ self,
+ train: bool = False,
+ subsample_fraction: float = None,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ ) -> None:
+ super().__init__(
+ train=train,
+ subsample_fraction=subsample_fraction,
+ transform=transform,
+ target_transform=target_transform,
+ )
+ # Load Iam dataset.
+ self.iam_dataset = IamDataset()
+
+ self.num_classes = 3
+ self._input_shape = (256, 256)
+ self._output_shape = self._input_shape + (self.num_classes,)
+ self._ids = None
+
+ def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]:
+ """Fetches data, target pair of the dataset for a given and index or indices.
+
+ Args:
+ index (Union[int, Tensor]): Either a list or int of indices/index.
+
+ Returns:
+ Tuple[Tensor, Tensor]: Data target pair.
+
+ """
+ if torch.is_tensor(index):
+ index = index.tolist()
+
+ data = self.data[index]
+ targets = self.targets[index]
+
+ seed = np.random.randint(SEED)
+ random.seed(seed) # apply this seed to target tranfsorms
+ torch.manual_seed(seed) # needed for torchvision 0.7
+ if self.transform:
+ data = self.transform(data)
+
+ random.seed(seed) # apply this seed to target tranfsorms
+ torch.manual_seed(seed) # needed for torchvision 0.7
+ if self.target_transform:
+ targets = self.target_transform(targets)
+
+ return data, targets.long()
+
+ @property
+ def ids(self) -> Tensor:
+ """Ids of the dataset."""
+ return self._ids
+
+ def get_data_and_target_from_id(self, id_: str) -> Tuple[Tensor, Tensor]:
+ """Get data target pair from id."""
+ ind = self.ids.index(id_)
+ return self.data[ind], self.targets[ind]
+
+ def load_or_generate_data(self) -> None:
+ """Load or generate dataset data."""
+ num_actual = len(list(CROPS_DIRNAME.glob("*.jpg")))
+ num_targets = len(self.iam_dataset.line_regions_by_id)
+
+ if num_actual < num_targets - 2:
+ self._process_iam_paragraphs()
+
+ self._data, self._targets, self._ids = _load_iam_paragraphs()
+ self._get_random_split()
+ self._subsample()
+
+ def _get_random_split(self) -> None:
+ np.random.seed(SEED)
+ num_train = int((1 - TEST_FRACTION) * self.data.shape[0])
+ indices = np.random.permutation(self.data.shape[0])
+ train_indices, test_indices = indices[:num_train], indices[num_train:]
+ if self.train:
+ self._data = self.data[train_indices]
+ self._targets = self.targets[train_indices]
+ else:
+ self._data = self.data[test_indices]
+ self._targets = self.targets[test_indices]
+
+ def _process_iam_paragraphs(self) -> None:
+ """Crop the part with the text.
+
+ For each page, crop out the part of it that correspond to the paragraph of text, and make sure all crops are
+ self.input_shape. The ground truth data is the same size, with a one-hot vector at each pixel
+ corresponding to labels 0=background, 1=odd-numbered line, 2=even-numbered line
+ """
+ crop_dims = self._decide_on_crop_dims()
+ CROPS_DIRNAME.mkdir(parents=True, exist_ok=True)
+ DEBUG_CROPS_DIRNAME.mkdir(parents=True, exist_ok=True)
+ GT_DIRNAME.mkdir(parents=True, exist_ok=True)
+ logger.info(
+ f"Cropping paragraphs, generating ground truth, and saving debugging images to {DEBUG_CROPS_DIRNAME}"
+ )
+ for filename in self.iam_dataset.form_filenames:
+ id_ = filename.stem
+ line_region = self.iam_dataset.line_regions_by_id[id_]
+ _crop_paragraph_image(filename, line_region, crop_dims, self.input_shape)
+
+ def _decide_on_crop_dims(self) -> Tuple[int, int]:
+ """Decide on the dimensions to crop out of the form image.
+
+ Since image width is larger than a comfortable crop around the longest paragraph,
+ we will make the crop a square form factor.
+ And since the found dimensions 610x610 are pretty close to 512x512,
+ we might as well resize crops and make it exactly that, which lets us
+ do all kinds of power-of-2 pooling and upsampling should we choose to.
+
+ Returns:
+ Tuple[int, int]: A tuple of crop dimensions.
+
+ Raises:
+ RuntimeError: When max crop height is larger than max crop width.
+
+ """
+
+ sample_form_filename = self.iam_dataset.form_filenames[0]
+ sample_image = util.read_image(sample_form_filename, grayscale=True)
+ max_crop_width = sample_image.shape[1]
+ max_crop_height = _get_max_paragraph_crop_height(
+ self.iam_dataset.line_regions_by_id
+ )
+ if not max_crop_height <= max_crop_width:
+ raise RuntimeError(
+ f"Max crop height is larger then max crop width: {max_crop_height} >= {max_crop_width}"
+ )
+
+ crop_dims = (max_crop_width, max_crop_width)
+ logger.info(
+ f"Max crop width and height were found to be {max_crop_width}x{max_crop_height}."
+ )
+ logger.info(f"Setting them to {max_crop_width}x{max_crop_width}")
+ return crop_dims
+
+ def __repr__(self) -> str:
+ """Return info about the dataset."""
+ return (
+ "IAM Paragraph Dataset\n" # pylint: disable=no-member
+ f"Num classes: {self.num_classes}\n"
+ f"Data: {self.data.shape}\n"
+ f"Targets: {self.targets.shape}\n"
+ )
+
+
+def _get_max_paragraph_crop_height(line_regions_by_id: Dict) -> int:
+ heights = []
+ for regions in line_regions_by_id.values():
+ min_y1 = min(region["y1"] for region in regions) - PARAGRAPH_BUFFER
+ max_y2 = max(region["y2"] for region in regions) + PARAGRAPH_BUFFER
+ height = max_y2 - min_y1
+ heights.append(height)
+ return max(heights)
+
+
+def _crop_paragraph_image(
+ filename: str, line_regions: Dict, crop_dims: Tuple[int, int], final_dims: Tuple
+) -> None:
+ image = util.read_image(filename, grayscale=True)
+
+ min_y1 = min(region["y1"] for region in line_regions) - PARAGRAPH_BUFFER
+ max_y2 = max(region["y2"] for region in line_regions) + PARAGRAPH_BUFFER
+ height = max_y2 - min_y1
+ crop_height = crop_dims[0]
+ buffer = (crop_height - height) // 2
+
+ # Generate image crop.
+ image_crop = 255 * np.ones(crop_dims, dtype=np.uint8)
+ try:
+ image_crop[buffer : buffer + height] = image[min_y1:max_y2]
+ except Exception as e: # pylint: disable=broad-except
+ logger.error(f"Rescued {filename}: {e}")
+ return
+
+ # Generate ground truth.
+ gt_image = np.zeros_like(image_crop, dtype=np.uint8)
+ for index, region in enumerate(line_regions):
+ gt_image[
+ (region["y1"] - min_y1 + buffer) : (region["y2"] - min_y1 + buffer),
+ region["x1"] : region["x2"],
+ ] = (index % 2 + 1)
+
+ # Generate image for debugging.
+ import matplotlib.pyplot as plt
+
+ cmap = plt.get_cmap("Set1")
+ image_crop_for_debug = np.dstack([image_crop, image_crop, image_crop])
+ for index, region in enumerate(line_regions):
+ color = [255 * _ for _ in cmap(index)[:-1]]
+ cv2.rectangle(
+ image_crop_for_debug,
+ (region["x1"], region["y1"] - min_y1 + buffer),
+ (region["x2"], region["y2"] - min_y1 + buffer),
+ color,
+ 3,
+ )
+ image_crop_for_debug = cv2.resize(
+ image_crop_for_debug, final_dims, interpolation=cv2.INTER_AREA
+ )
+ util.write_image(image_crop_for_debug, DEBUG_CROPS_DIRNAME / f"{filename.stem}.jpg")
+
+ image_crop = cv2.resize(image_crop, final_dims, interpolation=cv2.INTER_AREA)
+ util.write_image(image_crop, CROPS_DIRNAME / f"{filename.stem}.jpg")
+
+ gt_image = cv2.resize(gt_image, final_dims, interpolation=cv2.INTER_NEAREST)
+ util.write_image(gt_image, GT_DIRNAME / f"{filename.stem}.png")
+
+
+def _load_iam_paragraphs() -> None:
+ logger.info("Loading IAM paragraph crops and ground truth from image files...")
+ images = []
+ gt_images = []
+ ids = []
+ for filename in CROPS_DIRNAME.glob("*.jpg"):
+ id_ = filename.stem
+ image = util.read_image(filename, grayscale=True)
+ image = 1.0 - image / 255
+
+ gt_filename = GT_DIRNAME / f"{id_}.png"
+ gt_image = util.read_image(gt_filename, grayscale=True)
+
+ images.append(image)
+ gt_images.append(gt_image)
+ ids.append(id_)
+ images = np.array(images).astype(np.float32)
+ gt_images = np.array(gt_images).astype(np.uint8)
+ ids = np.array(ids)
+ return images, gt_images, ids
+
+
+@click.command()
+@click.option(
+ "--subsample_fraction",
+ type=float,
+ default=None,
+ help="The subsampling factor of the dataset.",
+)
+def main(subsample_fraction: float) -> None:
+ """Load dataset and print info."""
+ logger.info("Creating train set...")
+ dataset = IamParagraphsDataset(train=True, subsample_fraction=subsample_fraction)
+ dataset.load_or_generate_data()
+ print(dataset)
+ logger.info("Creating test set...")
+ dataset = IamParagraphsDataset(subsample_fraction=subsample_fraction)
+ dataset.load_or_generate_data()
+ print(dataset)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/text_recognizer/datasets/iam_preprocessor.py b/text_recognizer/datasets/iam_preprocessor.py
new file mode 100644
index 0000000..a93eb00
--- /dev/null
+++ b/text_recognizer/datasets/iam_preprocessor.py
@@ -0,0 +1,196 @@
+"""Preprocessor for extracting word letters from the IAM dataset.
+
+The code is mostly stolen from:
+
+ https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py
+
+"""
+
+import collections
+import itertools
+from pathlib import Path
+import re
+from typing import List, Optional, Union
+
+import click
+from loguru import logger
+import torch
+
+
+def load_metadata(
+ data_dir: Path, wordsep: str, use_words: bool = False
+) -> collections.defaultdict:
+ """Loads IAM metadata and returns it as a dictionary."""
+ forms = collections.defaultdict(list)
+ filename = "words.txt" if use_words else "lines.txt"
+
+ with open(data_dir / "ascii" / filename, "r") as f:
+ lines = (line.strip().split() for line in f if line[0] != "#")
+ for line in lines:
+ # Skip word segmentation errors.
+ if use_words and line[1] == "err":
+ continue
+ text = " ".join(line[8:])
+
+ # Remove garbage tokens:
+ text = text.replace("#", "")
+
+ # Swap word sep form | to wordsep
+ text = re.sub(r"\|+|\s", wordsep, text).strip(wordsep)
+ form_key = "-".join(line[0].split("-")[:2])
+ line_key = "-".join(line[0].split("-")[:3])
+ box_idx = 4 - use_words
+ box = tuple(int(val) for val in line[box_idx : box_idx + 4])
+ forms[form_key].append({"key": line_key, "box": box, "text": text})
+ return forms
+
+
+class Preprocessor:
+ """A preprocessor for the IAM dataset."""
+
+ # TODO: add lower case only to when generating...
+
+ def __init__(
+ self,
+ data_dir: Union[str, Path],
+ num_features: int,
+ tokens_path: Optional[Union[str, Path]] = None,
+ lexicon_path: Optional[Union[str, Path]] = None,
+ use_words: bool = False,
+ prepend_wordsep: bool = False,
+ ) -> None:
+ self.wordsep = "▁"
+ self._use_word = use_words
+ self._prepend_wordsep = prepend_wordsep
+
+ self.data_dir = Path(data_dir)
+
+ self.forms = load_metadata(self.data_dir, self.wordsep, use_words=use_words)
+
+ # Load the set of graphemes:
+ graphemes = set()
+ for _, form in self.forms.items():
+ for line in form:
+ graphemes.update(line["text"].lower())
+ self.graphemes = sorted(graphemes)
+
+ # Build the token-to-index and index-to-token maps.
+ if tokens_path is not None:
+ with open(tokens_path, "r") as f:
+ self.tokens = [line.strip() for line in f]
+ else:
+ self.tokens = self.graphemes
+
+ if lexicon_path is not None:
+ with open(lexicon_path, "r") as f:
+ lexicon = (line.strip().split() for line in f)
+ lexicon = {line[0]: line[1:] for line in lexicon}
+ self.lexicon = lexicon
+ else:
+ self.lexicon = None
+
+ self.graphemes_to_index = {t: i for i, t in enumerate(self.graphemes)}
+ self.tokens_to_index = {t: i for i, t in enumerate(self.tokens)}
+ self.num_features = num_features
+ self.text = []
+
+ @property
+ def num_tokens(self) -> int:
+ """Returns the number or tokens."""
+ return len(self.tokens)
+
+ @property
+ def use_words(self) -> bool:
+ """If words are used."""
+ return self._use_word
+
+ def extract_train_text(self) -> None:
+ """Extracts training text."""
+ keys = []
+ with open(self.data_dir / "task" / "trainset.txt") as f:
+ keys.extend((line.strip() for line in f))
+
+ for _, examples in self.forms.items():
+ for example in examples:
+ if example["key"] not in keys:
+ continue
+ self.text.append(example["text"].lower())
+
+ def to_index(self, line: str) -> torch.LongTensor:
+ """Converts text to a tensor of indices."""
+ token_to_index = self.graphemes_to_index
+ if self.lexicon is not None:
+ if len(line) > 0:
+ # If the word is not found in the lexicon, fall back to letters.
+ line = [
+ t
+ for w in line.split(self.wordsep)
+ for t in self.lexicon.get(w, self.wordsep + w)
+ ]
+ token_to_index = self.tokens_to_index
+ if self._prepend_wordsep:
+ line = itertools.chain([self.wordsep], line)
+ return torch.LongTensor([token_to_index[t] for t in line])
+
+ def to_text(self, indices: List[int]) -> str:
+ """Converts indices to text."""
+ # Roughly the inverse of `to_index`
+ encoding = self.graphemes
+ if self.lexicon is not None:
+ encoding = self.tokens
+ return self._post_process(encoding[i] for i in indices)
+
+ def tokens_to_text(self, indices: List[int]) -> str:
+ """Converts tokens to text."""
+ return self._post_process(self.tokens[i] for i in indices)
+
+ def _post_process(self, indices: List[int]) -> str:
+ """A list join."""
+ return "".join(indices).strip(self.wordsep)
+
+
+@click.command()
+@click.option("--data_dir", type=str, default=None, help="Path to iam dataset")
+@click.option(
+ "--use_words", is_flag=True, help="Load word segmented dataset instead of lines"
+)
+@click.option(
+ "--save_text", type=str, default=None, help="Path to save parsed train text"
+)
+@click.option("--save_tokens", type=str, default=None, help="Path to save tokens")
+def cli(
+ data_dir: Optional[str],
+ use_words: bool,
+ save_text: Optional[str],
+ save_tokens: Optional[str],
+) -> None:
+ """CLI for extracting text data from the iam dataset."""
+ if data_dir is None:
+ data_dir = (
+ Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb"
+ )
+ logger.debug(f"Using data dir: {data_dir}")
+ if not data_dir.exists():
+ raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+ else:
+ data_dir = Path(data_dir)
+
+ preprocessor = Preprocessor(data_dir, 64, use_words=use_words)
+ preprocessor.extract_train_text()
+
+ processed_dir = data_dir.parents[2] / "processed" / "iam_lines"
+ logger.debug(f"Saving processed files at: {processed_dir}")
+
+ if save_text is not None:
+ logger.info("Saving training text")
+ with open(processed_dir / save_text, "w") as f:
+ f.write("\n".join(t for t in preprocessor.text))
+
+ if save_tokens is not None:
+ logger.info("Saving tokens")
+ with open(processed_dir / save_tokens, "w") as f:
+ f.write("\n".join(preprocessor.tokens))
+
+
+if __name__ == "__main__":
+ cli()
diff --git a/text_recognizer/datasets/sentence_generator.py b/text_recognizer/datasets/sentence_generator.py
new file mode 100644
index 0000000..dd76652
--- /dev/null
+++ b/text_recognizer/datasets/sentence_generator.py
@@ -0,0 +1,81 @@
+"""Downloading the Brown corpus with NLTK for sentence generating."""
+
+import itertools
+import re
+import string
+from typing import Optional
+
+import nltk
+from nltk.corpus.reader.util import ConcatenatedCorpusView
+import numpy as np
+
+from text_recognizer.datasets.util import DATA_DIRNAME
+
+NLTK_DATA_DIRNAME = DATA_DIRNAME / "raw" / "nltk"
+
+
+class SentenceGenerator:
+ """Generates text sentences using the Brown corpus."""
+
+ def __init__(self, max_length: Optional[int] = None) -> None:
+ """Loads the corpus and sets word start indices."""
+ self.corpus = brown_corpus()
+ self.word_start_indices = [0] + [
+ _.start(0) + 1 for _ in re.finditer(" ", self.corpus)
+ ]
+ self.max_length = max_length
+
+ def generate(self, max_length: Optional[int] = None) -> str:
+ """Generates a word or sentences from the Brown corpus.
+
+ Sample a string from the Brown corpus of length at least one word and at most max_length, padding to
+ max_length with the '_' characters if sentence is shorter.
+
+ Args:
+ max_length (Optional[int]): The maximum number of characters in the sentence. Defaults to None.
+
+ Returns:
+ str: A sentence from the Brown corpus.
+
+ Raises:
+ ValueError: If max_length was not specified at initialization and not given as an argument.
+
+ """
+ if max_length is None:
+ max_length = self.max_length
+ if max_length is None:
+ raise ValueError(
+ "Must provide max_length to this method or when making this object."
+ )
+
+ index = np.random.randint(0, len(self.word_start_indices) - 1)
+ start_index = self.word_start_indices[index]
+ end_index_candidates = []
+ for index in range(index + 1, len(self.word_start_indices)):
+ if self.word_start_indices[index] - start_index > max_length:
+ break
+ end_index_candidates.append(self.word_start_indices[index])
+ end_index = np.random.choice(end_index_candidates)
+ sampled_text = self.corpus[start_index:end_index].strip()
+ padding = "_" * (max_length - len(sampled_text))
+ return sampled_text + padding
+
+
+def brown_corpus() -> str:
+ """Returns a single string with the Brown corpus with all punctuations stripped."""
+ sentences = load_nltk_brown_corpus()
+ corpus = " ".join(itertools.chain.from_iterable(sentences))
+ corpus = corpus.translate({ord(c): None for c in string.punctuation})
+ corpus = re.sub(" +", " ", corpus)
+ return corpus
+
+
+def load_nltk_brown_corpus() -> ConcatenatedCorpusView:
+ """Load the Brown corpus using the NLTK library."""
+ nltk.data.path.append(NLTK_DATA_DIRNAME)
+ try:
+ nltk.corpus.brown.sents()
+ except LookupError:
+ NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+ nltk.download("brown", download_dir=NLTK_DATA_DIRNAME)
+ return nltk.corpus.brown.sents()
diff --git a/text_recognizer/datasets/transforms.py b/text_recognizer/datasets/transforms.py
new file mode 100644
index 0000000..b6a48f5
--- /dev/null
+++ b/text_recognizer/datasets/transforms.py
@@ -0,0 +1,266 @@
+"""Transforms for PyTorch datasets."""
+from abc import abstractmethod
+from pathlib import Path
+import random
+from typing import Any, Optional, Union
+
+from loguru import logger
+import numpy as np
+from PIL import Image
+import torch
+from torch import Tensor
+import torch.nn.functional as F
+from torchvision import transforms
+from torchvision.transforms import (
+ ColorJitter,
+ Compose,
+ Normalize,
+ RandomAffine,
+ RandomHorizontalFlip,
+ RandomRotation,
+ ToPILImage,
+ ToTensor,
+)
+
+from text_recognizer.datasets.iam_preprocessor import Preprocessor
+from text_recognizer.datasets.util import EmnistMapper
+
+
+class RandomResizeCrop:
+ """Image transform with random resize and crop applied.
+
+ Stolen from
+
+ https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py
+
+ """
+
+ def __init__(self, jitter: int = 10, ratio: float = 0.5) -> None:
+ self.jitter = jitter
+ self.ratio = ratio
+
+ def __call__(self, img: np.ndarray) -> np.ndarray:
+ """Applies random crop and rotation to an image."""
+ w, h = img.size
+
+ # pad with white:
+ img = transforms.functional.pad(img, self.jitter, fill=255)
+
+ # crop at random (x, y):
+ x = self.jitter + random.randint(-self.jitter, self.jitter)
+ y = self.jitter + random.randint(-self.jitter, self.jitter)
+
+ # randomize aspect ratio:
+ size_w = w * random.uniform(1 - self.ratio, 1 + self.ratio)
+ size = (h, int(size_w))
+ img = transforms.functional.resized_crop(img, y, x, h, w, size)
+ return img
+
+
+class Transpose:
+ """Transposes the EMNIST image to the correct orientation."""
+
+ def __call__(self, image: Image) -> np.ndarray:
+ """Swaps axis."""
+ return np.array(image).swapaxes(0, 1)
+
+
+class Resize:
+ """Resizes a tensor to a specified width."""
+
+ def __init__(self, width: int = 952) -> None:
+ # The default is 952 because of the IAM dataset.
+ self.width = width
+
+ def __call__(self, image: Tensor) -> Tensor:
+ """Resize tensor in the last dimension."""
+ return F.interpolate(image, size=self.width, mode="nearest")
+
+
+class AddTokens:
+ """Adds start of sequence and end of sequence tokens to target tensor."""
+
+ def __init__(self, pad_token: str, eos_token: str, init_token: str = None) -> None:
+ self.init_token = init_token
+ self.pad_token = pad_token
+ self.eos_token = eos_token
+ if self.init_token is not None:
+ self.emnist_mapper = EmnistMapper(
+ init_token=self.init_token,
+ pad_token=self.pad_token,
+ eos_token=self.eos_token,
+ )
+ else:
+ self.emnist_mapper = EmnistMapper(
+ pad_token=self.pad_token, eos_token=self.eos_token,
+ )
+ self.pad_value = self.emnist_mapper(self.pad_token)
+ self.eos_value = self.emnist_mapper(self.eos_token)
+
+ def __call__(self, target: Tensor) -> Tensor:
+ """Adds a sos token to the begining and a eos token to the end of a target sequence."""
+ dtype, device = target.dtype, target.device
+
+ # Find the where padding starts.
+ pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item()
+
+ target[pad_index] = self.eos_value
+
+ if self.init_token is not None:
+ self.sos_value = self.emnist_mapper(self.init_token)
+ sos = torch.tensor([self.sos_value], dtype=dtype, device=device)
+ target = torch.cat([sos, target], dim=0)
+
+ return target
+
+
+class ApplyContrast:
+ """Sets everything below a threshold to zero, i.e. increase contrast."""
+
+ def __init__(self, low: float = 0.0, high: float = 0.25) -> None:
+ self.low = low
+ self.high = high
+
+ def __call__(self, x: Tensor) -> Tensor:
+ """Apply mask binary mask to input tensor."""
+ mask = x > np.random.RandomState().uniform(low=self.low, high=self.high)
+ return x * mask
+
+
+class Unsqueeze:
+ """Add a dimension to the tensor."""
+
+ def __call__(self, x: Tensor) -> Tensor:
+ """Adds dim."""
+ return x.unsqueeze(0)
+
+
+class Squeeze:
+ """Removes the first dimension of a tensor."""
+
+ def __call__(self, x: Tensor) -> Tensor:
+ """Removes first dim."""
+ return x.squeeze(0)
+
+
+class ToLower:
+ """Converts target to lower case."""
+
+ def __call__(self, target: Tensor) -> Tensor:
+ """Corrects index value in target tensor."""
+ device = target.device
+ return torch.stack([x - 26 if x > 35 else x for x in target]).to(device)
+
+
+class ToCharcters:
+ """Converts integers to characters."""
+
+ def __init__(
+ self, pad_token: str, eos_token: str, init_token: str = None, lower: bool = True
+ ) -> None:
+ self.init_token = init_token
+ self.pad_token = pad_token
+ self.eos_token = eos_token
+ if self.init_token is not None:
+ self.emnist_mapper = EmnistMapper(
+ init_token=self.init_token,
+ pad_token=self.pad_token,
+ eos_token=self.eos_token,
+ lower=lower,
+ )
+ else:
+ self.emnist_mapper = EmnistMapper(
+ pad_token=self.pad_token, eos_token=self.eos_token, lower=lower
+ )
+
+ def __call__(self, y: Tensor) -> str:
+ """Converts a Tensor to a str."""
+ return (
+ "".join([self.emnist_mapper(int(i)) for i in y])
+ .strip("_")
+ .replace(" ", "▁")
+ )
+
+
+class WordPieces:
+ """Abstract transform for word pieces."""
+
+ def __init__(
+ self,
+ num_features: int,
+ data_dir: Optional[Union[str, Path]] = None,
+ tokens: Optional[Union[str, Path]] = None,
+ lexicon: Optional[Union[str, Path]] = None,
+ use_words: bool = False,
+ prepend_wordsep: bool = False,
+ ) -> None:
+ if data_dir is None:
+ data_dir = (
+ Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb"
+ )
+ logger.debug(f"Using data dir: {data_dir}")
+ if not data_dir.exists():
+ raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+ else:
+ data_dir = Path(data_dir)
+ processed_path = (
+ Path(__file__).resolve().parents[3] / "data" / "processed" / "iam_lines"
+ )
+ tokens_path = processed_path / tokens
+ lexicon_path = processed_path / lexicon
+
+ self.preprocessor = Preprocessor(
+ data_dir,
+ num_features,
+ tokens_path,
+ lexicon_path,
+ use_words,
+ prepend_wordsep,
+ )
+
+ @abstractmethod
+ def __call__(self, *args, **kwargs) -> Any:
+ """Transforms input."""
+ ...
+
+
+class ToWordPieces(WordPieces):
+ """Transforms str to word pieces."""
+
+ def __init__(
+ self,
+ num_features: int,
+ data_dir: Optional[Union[str, Path]] = None,
+ tokens: Optional[Union[str, Path]] = None,
+ lexicon: Optional[Union[str, Path]] = None,
+ use_words: bool = False,
+ prepend_wordsep: bool = False,
+ ) -> None:
+ super().__init__(
+ num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep
+ )
+
+ def __call__(self, line: str) -> Tensor:
+ """Transforms str to word pieces."""
+ return self.preprocessor.to_index(line)
+
+
+class ToText(WordPieces):
+ """Takes word pieces and converts them to text."""
+
+ def __init__(
+ self,
+ num_features: int,
+ data_dir: Optional[Union[str, Path]] = None,
+ tokens: Optional[Union[str, Path]] = None,
+ lexicon: Optional[Union[str, Path]] = None,
+ use_words: bool = False,
+ prepend_wordsep: bool = False,
+ ) -> None:
+ super().__init__(
+ num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep
+ )
+
+ def __call__(self, x: Tensor) -> str:
+ """Converts tensor to text."""
+ return self.preprocessor.to_text(x.tolist())
diff --git a/text_recognizer/datasets/util.py b/text_recognizer/datasets/util.py
new file mode 100644
index 0000000..da87756
--- /dev/null
+++ b/text_recognizer/datasets/util.py
@@ -0,0 +1,209 @@
+"""Util functions for datasets."""
+import hashlib
+import json
+import os
+from pathlib import Path
+import string
+from typing import Dict, List, Optional, Union
+from urllib.request import urlretrieve
+
+from loguru import logger
+import numpy as np
+import torch
+from torch import Tensor
+from torchvision.datasets import EMNIST
+from tqdm import tqdm
+
+DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data"
+ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json"
+
+
+def save_emnist_essentials(emnsit_dataset: EMNIST = EMNIST) -> None:
+ """Extract and saves EMNIST essentials."""
+ labels = emnsit_dataset.classes
+ labels.sort()
+ mapping = [(i, str(label)) for i, label in enumerate(labels)]
+ essentials = {
+ "mapping": mapping,
+ "input_shape": tuple(np.array(emnsit_dataset[0][0]).shape[:]),
+ }
+ logger.info("Saving emnist essentials...")
+ with open(ESSENTIALS_FILENAME, "w") as f:
+ json.dump(essentials, f)
+
+
+def download_emnist() -> None:
+ """Download the EMNIST dataset via the PyTorch class."""
+ logger.info(f"Data directory is: {DATA_DIRNAME}")
+ dataset = EMNIST(root=DATA_DIRNAME, split="byclass", download=True)
+ save_emnist_essentials(dataset)
+
+
+class EmnistMapper:
+ """Mapper between network output to Emnist character."""
+
+ def __init__(
+ self,
+ pad_token: str,
+ init_token: Optional[str] = None,
+ eos_token: Optional[str] = None,
+ lower: bool = False,
+ ) -> None:
+ """Loads the emnist essentials file with the mapping and input shape."""
+ self.init_token = init_token
+ self.pad_token = pad_token
+ self.eos_token = eos_token
+ self.lower = lower
+
+ self.essentials = self._load_emnist_essentials()
+ # Load dataset information.
+ self._mapping = dict(self.essentials["mapping"])
+ self._augment_emnist_mapping()
+ self._inverse_mapping = {v: k for k, v in self.mapping.items()}
+ self._num_classes = len(self.mapping)
+ self._input_shape = self.essentials["input_shape"]
+
+ def __call__(self, token: Union[str, int, np.uint8, Tensor]) -> Union[str, int]:
+ """Maps the token to emnist character or character index.
+
+ If the token is an integer (index), the method will return the Emnist character corresponding to that index.
+ If the token is a str (Emnist character), the method will return the corresponding index for that character.
+
+ Args:
+ token (Union[str, int, np.uint8, Tensor]): Either a string or index (integer).
+
+ Returns:
+ Union[str, int]: The mapping result.
+
+ Raises:
+ KeyError: If the index or string does not exist in the mapping.
+
+ """
+ if (
+ (isinstance(token, np.uint8) or isinstance(token, int))
+ or torch.is_tensor(token)
+ and int(token) in self.mapping
+ ):
+ return self.mapping[int(token)]
+ elif isinstance(token, str) and token in self._inverse_mapping:
+ return self._inverse_mapping[token]
+ else:
+ raise KeyError(f"Token {token} does not exist in the mappings.")
+
+ @property
+ def mapping(self) -> Dict:
+ """Returns the mapping between index and character."""
+ return self._mapping
+
+ @property
+ def inverse_mapping(self) -> Dict:
+ """Returns the mapping between character and index."""
+ return self._inverse_mapping
+
+ @property
+ def num_classes(self) -> int:
+ """Returns the number of classes in the dataset."""
+ return self._num_classes
+
+ @property
+ def input_shape(self) -> List[int]:
+ """Returns the input shape of the Emnist characters."""
+ return self._input_shape
+
+ def _load_emnist_essentials(self) -> Dict:
+ """Load the EMNIST mapping."""
+ with open(str(ESSENTIALS_FILENAME)) as f:
+ essentials = json.load(f)
+ return essentials
+
+ def _augment_emnist_mapping(self) -> None:
+ """Augment the mapping with extra symbols."""
+ # Extra symbols in IAM dataset
+ if self.lower:
+ self._mapping = {
+ k: str(v)
+ for k, v in enumerate(list(range(10)) + list(string.ascii_lowercase))
+ }
+
+ extra_symbols = [
+ " ",
+ "!",
+ '"',
+ "#",
+ "&",
+ "'",
+ "(",
+ ")",
+ "*",
+ "+",
+ ",",
+ "-",
+ ".",
+ "/",
+ ":",
+ ";",
+ "?",
+ ]
+
+ # padding symbol, and acts as blank symbol as well.
+ extra_symbols.append(self.pad_token)
+
+ if self.init_token is not None:
+ extra_symbols.append(self.init_token)
+
+ if self.eos_token is not None:
+ extra_symbols.append(self.eos_token)
+
+ max_key = max(self.mapping.keys())
+ extra_mapping = {}
+ for i, symbol in enumerate(extra_symbols):
+ extra_mapping[max_key + 1 + i] = symbol
+
+ self._mapping = {**self.mapping, **extra_mapping}
+
+
+def compute_sha256(filename: Union[Path, str]) -> str:
+ """Returns the SHA256 checksum of a file."""
+ with open(filename, "rb") as f:
+ return hashlib.sha256(f.read()).hexdigest()
+
+
+class TqdmUpTo(tqdm):
+ """TQDM progress bar when downloading files.
+
+ From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
+
+ """
+
+ def update_to(
+ self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None
+ ) -> None:
+ """Updates the progress bar.
+
+ Args:
+ blocks (int): Number of blocks transferred so far. Defaults to 1.
+ block_size (int): Size of each block, in tqdm units. Defaults to 1.
+ total_size (Optional[int]): Total size in tqdm units. Defaults to None.
+ """
+ if total_size is not None:
+ self.total = total_size # pylint: disable=attribute-defined-outside-init
+ self.update(blocks * block_size - self.n)
+
+
+def download_url(url: str, filename: str) -> None:
+ """Downloads a file from url to filename, with a progress bar."""
+ with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
+ urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec
+
+
+def _download_raw_dataset(metadata: Dict) -> None:
+ if os.path.exists(metadata["filename"]):
+ return
+ logger.info(f"Downloading raw dataset from {metadata['url']}...")
+ download_url(metadata["url"], metadata["filename"])
+ logger.info("Computing SHA-256...")
+ sha256 = compute_sha256(metadata["filename"])
+ if sha256 != metadata["sha256"]:
+ raise ValueError(
+ "Downloaded data file SHA-256 does not match that listed in metadata document."
+ )