summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/datasets')
-rw-r--r--src/text_recognizer/datasets/__init__.py39
-rw-r--r--src/text_recognizer/datasets/dataset.py152
-rw-r--r--src/text_recognizer/datasets/emnist_dataset.py131
-rw-r--r--src/text_recognizer/datasets/emnist_essentials.json1
-rw-r--r--src/text_recognizer/datasets/emnist_lines_dataset.py359
-rw-r--r--src/text_recognizer/datasets/iam_dataset.py132
-rw-r--r--src/text_recognizer/datasets/iam_lines_dataset.py110
-rw-r--r--src/text_recognizer/datasets/iam_paragraphs_dataset.py291
-rw-r--r--src/text_recognizer/datasets/iam_preprocessor.py196
-rw-r--r--src/text_recognizer/datasets/sentence_generator.py81
-rw-r--r--src/text_recognizer/datasets/transforms.py266
-rw-r--r--src/text_recognizer/datasets/util.py209
12 files changed, 0 insertions, 1967 deletions
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py
deleted file mode 100644
index a6c1c59..0000000
--- a/src/text_recognizer/datasets/__init__.py
+++ /dev/null
@@ -1,39 +0,0 @@
-"""Dataset modules."""
-from .emnist_dataset import EmnistDataset
-from .emnist_lines_dataset import (
- construct_image_from_string,
- EmnistLinesDataset,
- get_samples_by_character,
-)
-from .iam_dataset import IamDataset
-from .iam_lines_dataset import IamLinesDataset
-from .iam_paragraphs_dataset import IamParagraphsDataset
-from .iam_preprocessor import load_metadata, Preprocessor
-from .transforms import AddTokens, Transpose
-from .util import (
- _download_raw_dataset,
- compute_sha256,
- DATA_DIRNAME,
- download_url,
- EmnistMapper,
- ESSENTIALS_FILENAME,
-)
-
-__all__ = [
- "_download_raw_dataset",
- "AddTokens",
- "compute_sha256",
- "construct_image_from_string",
- "DATA_DIRNAME",
- "download_url",
- "EmnistDataset",
- "EmnistMapper",
- "EmnistLinesDataset",
- "get_samples_by_character",
- "load_metadata",
- "IamDataset",
- "IamLinesDataset",
- "IamParagraphsDataset",
- "Preprocessor",
- "Transpose",
-]
diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py
deleted file mode 100644
index e794605..0000000
--- a/src/text_recognizer/datasets/dataset.py
+++ /dev/null
@@ -1,152 +0,0 @@
-"""Abstract dataset class."""
-from typing import Callable, Dict, List, Optional, Tuple, Union
-
-import torch
-from torch import Tensor
-from torch.utils import data
-from torchvision.transforms import ToTensor
-
-import text_recognizer.datasets.transforms as transforms
-from text_recognizer.datasets.util import EmnistMapper
-
-
-class Dataset(data.Dataset):
- """Abstract class for with common methods for all datasets."""
-
- def __init__(
- self,
- train: bool,
- subsample_fraction: float = None,
- transform: Optional[List[Dict]] = None,
- target_transform: Optional[List[Dict]] = None,
- init_token: Optional[str] = None,
- pad_token: Optional[str] = None,
- eos_token: Optional[str] = None,
- lower: bool = False,
- ) -> None:
- """Initialization of Dataset class.
-
- Args:
- train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False.
- subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None.
- transform (Optional[List[Dict]]): List of Transform types and args for input data. Defaults to None.
- target_transform (Optional[List[Dict]]): List of Transform types and args for output data. Defaults to None.
- init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
- pad_token (Optional[str]): String representing the pad token. Defaults to None.
- eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
- lower (bool): Only use lower case letters. Defaults to False.
-
- Raises:
- ValueError: If subsample_fraction is not None and outside the range (0, 1).
-
- """
- self.train = train
- self.split = "train" if self.train else "test"
-
- if subsample_fraction is not None:
- if not 0.0 < subsample_fraction < 1.0:
- raise ValueError("The subsample fraction must be in (0, 1).")
- self.subsample_fraction = subsample_fraction
-
- self._mapper = EmnistMapper(
- init_token=init_token, eos_token=eos_token, pad_token=pad_token, lower=lower
- )
- self._input_shape = self._mapper.input_shape
- self._output_shape = self._mapper._num_classes
- self.num_classes = self.mapper.num_classes
-
- # Set transforms.
- self.transform = self._configure_transform(transform)
- self.target_transform = self._configure_target_transform(target_transform)
-
- self._data = None
- self._targets = None
-
- def _configure_transform(self, transform: List[Dict]) -> transforms.Compose:
- transform_list = []
- if transform is not None:
- for t in transform:
- t_type = t["type"]
- t_args = t["args"] or {}
- transform_list.append(getattr(transforms, t_type)(**t_args))
- else:
- transform_list.append(ToTensor())
- return transforms.Compose(transform_list)
-
- def _configure_target_transform(
- self, target_transform: List[Dict]
- ) -> transforms.Compose:
- target_transform_list = [torch.tensor]
- if target_transform is not None:
- for t in target_transform:
- t_type = t["type"]
- t_args = t["args"] or {}
- target_transform_list.append(getattr(transforms, t_type)(**t_args))
- return transforms.Compose(target_transform_list)
-
- @property
- def data(self) -> Tensor:
- """The input data."""
- return self._data
-
- @property
- def targets(self) -> Tensor:
- """The target data."""
- return self._targets
-
- @property
- def input_shape(self) -> Tuple:
- """Input shape of the data."""
- return self._input_shape
-
- @property
- def output_shape(self) -> Tuple:
- """Output shape of the data."""
- return self._output_shape
-
- @property
- def mapper(self) -> EmnistMapper:
- """Returns the EmnistMapper."""
- return self._mapper
-
- @property
- def mapping(self) -> Dict:
- """Return EMNIST mapping from index to character."""
- return self._mapper.mapping
-
- @property
- def inverse_mapping(self) -> Dict:
- """Returns the inverse mapping from character to index."""
- return self.mapper.inverse_mapping
-
- def _subsample(self) -> None:
- """Only this fraction of the data will be loaded."""
- if self.subsample_fraction is None:
- return
- num_subsample = int(self.data.shape[0] * self.subsample_fraction)
- self._data = self.data[:num_subsample]
- self._targets = self.targets[:num_subsample]
-
- def __len__(self) -> int:
- """Returns the length of the dataset."""
- return len(self.data)
-
- def load_or_generate_data(self) -> None:
- """Load or generate dataset data."""
- raise NotImplementedError
-
- def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
- """Fetches samples from the dataset.
-
- Args:
- index (Union[int, torch.Tensor]): The indices of the samples to fetch.
-
- Raises:
- NotImplementedError: If the method is not implemented in child class.
-
- """
- raise NotImplementedError
-
- def __repr__(self) -> str:
- """Returns information about the dataset."""
- raise NotImplementedError
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
deleted file mode 100644
index 9884fdf..0000000
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ /dev/null
@@ -1,131 +0,0 @@
-"""Emnist dataset: black and white images of handwritten characters (Aa-Zz) and digits (0-9)."""
-
-import json
-from pathlib import Path
-from typing import Callable, Optional, Tuple, Union
-
-from loguru import logger
-import numpy as np
-from PIL import Image
-import torch
-from torch import Tensor
-from torchvision.datasets import EMNIST
-from torchvision.transforms import Compose, ToTensor
-
-from text_recognizer.datasets.dataset import Dataset
-from text_recognizer.datasets.transforms import Transpose
-from text_recognizer.datasets.util import DATA_DIRNAME
-
-
-class EmnistDataset(Dataset):
- """This is a class for resampling and subsampling the PyTorch EMNIST dataset."""
-
- def __init__(
- self,
- pad_token: str = None,
- train: bool = False,
- sample_to_balance: bool = False,
- subsample_fraction: float = None,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
- seed: int = 4711,
- ) -> None:
- """Loads the dataset and the mappings.
-
- Args:
- pad_token (str): The pad token symbol. Defaults to _.
- train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False.
- sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False.
- subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None.
- transform (Optional[Callable]): Transform(s) for input data. Defaults to None.
- target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None.
- seed (int): Seed number. Defaults to 4711.
-
- """
- super().__init__(
- train=train,
- subsample_fraction=subsample_fraction,
- transform=transform,
- target_transform=target_transform,
- pad_token=pad_token,
- )
-
- self.sample_to_balance = sample_to_balance
-
- # Have to transpose the emnist characters, ToTensor norms input between [0,1].
- if transform is None:
- self.transform = Compose([Transpose(), ToTensor()])
-
- self.target_transform = None
-
- self.seed = seed
-
- def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
- """Fetches samples from the dataset.
-
- Args:
- index (Union[int, Tensor]): The indices of the samples to fetch.
-
- Returns:
- Tuple[Tensor, Tensor]: Data target tuple.
-
- """
- if torch.is_tensor(index):
- index = index.tolist()
-
- data = self.data[index]
- targets = self.targets[index]
-
- if self.transform:
- data = self.transform(data)
-
- if self.target_transform:
- targets = self.target_transform(targets)
-
- return data, targets
-
- def __repr__(self) -> str:
- """Returns information about the dataset."""
- return (
- "EMNIST Dataset\n"
- f"Num classes: {self.num_classes}\n"
- f"Input shape: {self.input_shape}\n"
- f"Mapping: {self.mapper.mapping}\n"
- )
-
- def _sample_to_balance(self) -> None:
- """Because the dataset is not balanced, we take at most the mean number of instances per class."""
- np.random.seed(self.seed)
- x = self._data
- y = self._targets
- num_to_sample = int(np.bincount(y.flatten()).mean())
- all_sampled_indices = []
- for label in np.unique(y.flatten()):
- inds = np.where(y == label)[0]
- sampled_indices = np.unique(np.random.choice(inds, num_to_sample))
- all_sampled_indices.append(sampled_indices)
- indices = np.concatenate(all_sampled_indices)
- x_sampled = x[indices]
- y_sampled = y[indices]
- self._data = x_sampled
- self._targets = y_sampled
-
- def load_or_generate_data(self) -> None:
- """Fetch the EMNIST dataset."""
- dataset = EMNIST(
- root=DATA_DIRNAME,
- split="byclass",
- train=self.train,
- download=False,
- transform=None,
- target_transform=None,
- )
-
- self._data = dataset.data
- self._targets = dataset.targets
-
- if self.sample_to_balance:
- self._sample_to_balance()
-
- if self.subsample_fraction is not None:
- self._subsample()
diff --git a/src/text_recognizer/datasets/emnist_essentials.json b/src/text_recognizer/datasets/emnist_essentials.json
deleted file mode 100644
index 2a0648a..0000000
--- a/src/text_recognizer/datasets/emnist_essentials.json
+++ /dev/null
@@ -1 +0,0 @@
-{"mapping": [[0, "0"], [1, "1"], [2, "2"], [3, "3"], [4, "4"], [5, "5"], [6, "6"], [7, "7"], [8, "8"], [9, "9"], [10, "A"], [11, "B"], [12, "C"], [13, "D"], [14, "E"], [15, "F"], [16, "G"], [17, "H"], [18, "I"], [19, "J"], [20, "K"], [21, "L"], [22, "M"], [23, "N"], [24, "O"], [25, "P"], [26, "Q"], [27, "R"], [28, "S"], [29, "T"], [30, "U"], [31, "V"], [32, "W"], [33, "X"], [34, "Y"], [35, "Z"], [36, "a"], [37, "b"], [38, "c"], [39, "d"], [40, "e"], [41, "f"], [42, "g"], [43, "h"], [44, "i"], [45, "j"], [46, "k"], [47, "l"], [48, "m"], [49, "n"], [50, "o"], [51, "p"], [52, "q"], [53, "r"], [54, "s"], [55, "t"], [56, "u"], [57, "v"], [58, "w"], [59, "x"], [60, "y"], [61, "z"]], "input_shape": [28, 28]}
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
deleted file mode 100644
index 1992446..0000000
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ /dev/null
@@ -1,359 +0,0 @@
-"""Emnist Lines dataset: synthetic handwritten lines dataset made from Emnist characters."""
-
-from collections import defaultdict
-from pathlib import Path
-from typing import Callable, Dict, List, Optional, Tuple, Union
-
-import click
-import h5py
-from loguru import logger
-import numpy as np
-import torch
-from torch import Tensor
-import torch.nn.functional as F
-from torchvision.transforms import ToTensor
-
-from text_recognizer.datasets.dataset import Dataset
-from text_recognizer.datasets.emnist_dataset import EmnistDataset, Transpose
-from text_recognizer.datasets.sentence_generator import SentenceGenerator
-from text_recognizer.datasets.util import (
- DATA_DIRNAME,
- EmnistMapper,
- ESSENTIALS_FILENAME,
-)
-
-DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines"
-
-MAX_WIDTH = 952
-
-
-class EmnistLinesDataset(Dataset):
- """Synthetic dataset of lines from the Brown corpus with Emnist characters."""
-
- def __init__(
- self,
- train: bool = False,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
- subsample_fraction: float = None,
- max_length: int = 34,
- min_overlap: float = 0,
- max_overlap: float = 0.33,
- num_samples: int = 10000,
- seed: int = 4711,
- init_token: Optional[str] = None,
- pad_token: Optional[str] = None,
- eos_token: Optional[str] = None,
- lower: bool = False,
- ) -> None:
- """Set attributes and loads the dataset.
-
- Args:
- train (bool): Flag for the filename. Defaults to False. Defaults to None.
- transform (Optional[Callable]): The transform of the data. Defaults to None.
- target_transform (Optional[Callable]): The transform of the target. Defaults to None.
- subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None.
- max_length (int): The maximum number of characters. Defaults to 34.
- min_overlap (float): The minimum overlap between concatenated images. Defaults to 0.
- max_overlap (float): The maximum overlap between concatenated images. Defaults to 0.33.
- num_samples (int): Number of samples to generate. Defaults to 10000.
- seed (int): Seed number. Defaults to 4711.
- init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
- pad_token (Optional[str]): String representing the pad token. Defaults to None.
- eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
- lower (bool): If True, convert uppercase letters to lowercase. Otherwise, use both upper and lowercase.
-
- """
- self.pad_token = "_" if pad_token is None else pad_token
-
- super().__init__(
- train=train,
- transform=transform,
- target_transform=target_transform,
- subsample_fraction=subsample_fraction,
- init_token=init_token,
- pad_token=self.pad_token,
- eos_token=eos_token,
- lower=lower,
- )
-
- # Extract dataset information.
- self._input_shape = self._mapper.input_shape
- self.num_classes = self._mapper.num_classes
-
- self.max_length = max_length
- self.min_overlap = min_overlap
- self.max_overlap = max_overlap
- self.num_samples = num_samples
- self._input_shape = (
- self.input_shape[0],
- self.input_shape[1] * self.max_length,
- )
- self._output_shape = (self.max_length, self.num_classes)
- self.seed = seed
-
- # Placeholders for the dataset.
- self._data = None
- self._target = None
-
- def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
- """Fetches data, target pair of the dataset for a given and index or indices.
-
- Args:
- index (Union[int, Tensor]): Either a list or int of indices/index.
-
- Returns:
- Tuple[Tensor, Tensor]: Data target pair.
-
- """
- if torch.is_tensor(index):
- index = index.tolist()
-
- data = self.data[index]
- targets = self.targets[index]
-
- if self.transform:
- data = self.transform(data)
-
- if self.target_transform:
- targets = self.target_transform(targets)
-
- return data, targets
-
- def __repr__(self) -> str:
- """Returns information about the dataset."""
- return (
- "EMNIST Lines Dataset\n" # pylint: disable=no-member
- f"Max length: {self.max_length}\n"
- f"Min overlap: {self.min_overlap}\n"
- f"Max overlap: {self.max_overlap}\n"
- f"Num classes: {self.num_classes}\n"
- f"Input shape: {self.input_shape}\n"
- f"Data: {self.data.shape}\n"
- f"Tagets: {self.targets.shape}\n"
- )
-
- @property
- def data_filename(self) -> Path:
- """Path to the h5 file."""
- filename = "train.pt" if self.train else "test.pt"
- return DATA_DIRNAME / filename
-
- def load_or_generate_data(self) -> None:
- """Loads the dataset, if it does not exist a new dataset is generated before loading it."""
- np.random.seed(self.seed)
-
- if not self.data_filename.exists():
- self._generate_data()
- self._load_data()
- self._subsample()
-
- def _load_data(self) -> None:
- """Loads the dataset from the h5 file."""
- logger.debug("EmnistLinesDataset loading data from HDF5...")
- with h5py.File(self.data_filename, "r") as f:
- self._data = f["data"][()]
- self._targets = f["targets"][()]
-
- def _generate_data(self) -> str:
- """Generates a dataset with the Brown corpus and Emnist characters."""
- logger.debug("Generating data...")
-
- sentence_generator = SentenceGenerator(self.max_length)
-
- # Load emnist dataset.
- emnist = EmnistDataset(
- train=self.train, sample_to_balance=True, pad_token=self.pad_token
- )
- emnist.load_or_generate_data()
-
- samples_by_character = get_samples_by_character(
- emnist.data.numpy(), emnist.targets.numpy(), self.mapper.mapping,
- )
-
- DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
- with h5py.File(self.data_filename, "a") as f:
- data, targets = create_dataset_of_images(
- self.num_samples,
- samples_by_character,
- sentence_generator,
- self.min_overlap,
- self.max_overlap,
- )
-
- targets = convert_strings_to_categorical_labels(
- targets, emnist.inverse_mapping
- )
-
- f.create_dataset("data", data=data, dtype="u1", compression="lzf")
- f.create_dataset("targets", data=targets, dtype="u1", compression="lzf")
-
-
-def get_samples_by_character(
- samples: np.ndarray, labels: np.ndarray, mapping: Dict
-) -> defaultdict:
- """Creates a dictionary with character as key and value as the list of images of that character.
-
- Args:
- samples (np.ndarray): Dataset of images of characters.
- labels (np.ndarray): The labels for each image.
- mapping (Dict): The Emnist mapping dictionary.
-
- Returns:
- defaultdict: A dictionary with characters as keys and list of images as values.
-
- """
- samples_by_character = defaultdict(list)
- for sample, label in zip(samples, labels.flatten()):
- samples_by_character[mapping[label]].append(sample)
- return samples_by_character
-
-
-def select_letter_samples_for_string(
- string: str, samples_by_character: Dict
-) -> List[np.ndarray]:
- """Randomly selects Emnist characters to use for the senetence.
-
- Args:
- string (str): The word or sentence.
- samples_by_character (Dict): The dictionary of emnist images of each character.
-
- Returns:
- List[np.ndarray]: A list of emnist images of the string.
-
- """
- zero_image = np.zeros((28, 28), np.uint8)
- sample_image_by_character = {}
- for character in string:
- if character in sample_image_by_character:
- continue
- samples = samples_by_character[character]
- sample = samples[np.random.choice(len(samples))] if samples else zero_image
- sample_image_by_character[character] = sample.reshape(28, 28).swapaxes(0, 1)
- return [sample_image_by_character[character] for character in string]
-
-
-def construct_image_from_string(
- string: str, samples_by_character: Dict, min_overlap: float, max_overlap: float
-) -> np.ndarray:
- """Concatenates images of the characters in the string.
-
- The concatination is made with randomly selected overlap so that some portion of the character will overlap.
-
- Args:
- string (str): The word or sentence.
- samples_by_character (Dict): The dictionary of emnist images of each character.
- min_overlap (float): Minimum amount of overlap between Emnist images.
- max_overlap (float): Maximum amount of overlap between Emnist images.
-
- Returns:
- np.ndarray: The Emnist image of the string.
-
- """
- overlap = np.random.uniform(min_overlap, max_overlap)
- sampled_images = select_letter_samples_for_string(string, samples_by_character)
- length = len(sampled_images)
- height, width = sampled_images[0].shape
- next_overlap_width = width - int(overlap * width)
- concatenated_image = np.zeros((height, width * length), np.uint8)
- x = 0
- for image in sampled_images:
- concatenated_image[:, x : (x + width)] += image
- x += next_overlap_width
-
- if concatenated_image.shape[-1] > MAX_WIDTH:
- concatenated_image = Tensor(concatenated_image).unsqueeze(0)
- concatenated_image = F.interpolate(
- concatenated_image, size=MAX_WIDTH, mode="nearest"
- )
- concatenated_image = concatenated_image.squeeze(0).numpy()
-
- return np.minimum(255, concatenated_image)
-
-
-def create_dataset_of_images(
- length: int,
- samples_by_character: Dict,
- sentence_generator: SentenceGenerator,
- min_overlap: float,
- max_overlap: float,
-) -> Tuple[np.ndarray, List[str]]:
- """Creates a dataset with images and labels from strings generated from the SentenceGenerator.
-
- Args:
- length (int): The number of characters for each string.
- samples_by_character (Dict): The dictionary of emnist images of each character.
- sentence_generator (SentenceGenerator): A SentenceGenerator objest.
- min_overlap (float): Minimum amount of overlap between Emnist images.
- max_overlap (float): Maximum amount of overlap between Emnist images.
-
- Returns:
- Tuple[np.ndarray, List[str]]: A list of Emnist images and a list of the strings (labels).
-
- Raises:
- RuntimeError: If the sentence generator is not able to generate a string.
-
- """
- sample_label = sentence_generator.generate()
- sample_image = construct_image_from_string(sample_label, samples_by_character, 0, 0)
- images = np.zeros((length, sample_image.shape[0], sample_image.shape[1]), np.uint8)
- labels = []
- for n in range(length):
- label = None
- # Try several times to generate before actually throwing an error.
- for _ in range(10):
- try:
- label = sentence_generator.generate()
- break
- except Exception: # pylint: disable=broad-except
- pass
- if label is None:
- raise RuntimeError("Was not able to generate a valid string.")
- images[n] = construct_image_from_string(
- label, samples_by_character, min_overlap, max_overlap
- )
- labels.append(label)
- return images, labels
-
-
-def convert_strings_to_categorical_labels(
- labels: List[str], mapping: Dict
-) -> np.ndarray:
- """Translates a string of characters in to a target array of class int."""
- return np.array([[mapping[c] for c in label] for label in labels])
-
-
-@click.command()
-@click.option(
- "--max_length", type=int, default=34, help="Number of characters in a sentence."
-)
-@click.option(
- "--min_overlap", type=float, default=0.0, help="Min overlap between characters."
-)
-@click.option(
- "--max_overlap", type=float, default=0.33, help="Max overlap between characters."
-)
-@click.option("--num_train", type=int, default=10_000, help="Number of train examples.")
-@click.option("--num_test", type=int, default=1_000, help="Number of test examples.")
-def create_datasets(
- max_length: int = 34,
- min_overlap: float = 0,
- max_overlap: float = 0.33,
- num_train: int = 10000,
- num_test: int = 1000,
-) -> None:
- """Creates a training an validation dataset of Emnist lines."""
- num_samples = [num_train, num_test]
- for num, train in zip(num_samples, [True, False]):
- emnist_lines = EmnistLinesDataset(
- train=train,
- max_length=max_length,
- min_overlap=min_overlap,
- max_overlap=max_overlap,
- num_samples=num,
- )
- emnist_lines.load_or_generate_data()
-
-
-if __name__ == "__main__":
- create_datasets()
diff --git a/src/text_recognizer/datasets/iam_dataset.py b/src/text_recognizer/datasets/iam_dataset.py
deleted file mode 100644
index f4a869d..0000000
--- a/src/text_recognizer/datasets/iam_dataset.py
+++ /dev/null
@@ -1,132 +0,0 @@
-"""Class for loading the IAM dataset, which encompasses both paragraphs and lines, with associated utilities."""
-import os
-from typing import Any, Dict, List
-import zipfile
-
-from boltons.cacheutils import cachedproperty
-import defusedxml.ElementTree as ET
-from loguru import logger
-import toml
-
-from text_recognizer.datasets.util import _download_raw_dataset, DATA_DIRNAME
-
-RAW_DATA_DIRNAME = DATA_DIRNAME / "raw" / "iam"
-METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
-EXTRACTED_DATASET_DIRNAME = RAW_DATA_DIRNAME / "iamdb"
-
-DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be.
-LINE_REGION_PADDING = 0 # Add this many pixels around the exact coordinates.
-
-
-class IamDataset:
- """IAM dataset.
-
- "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text,
- which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels."
- From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database
-
- The data split we will use is
- IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text lines.
- The validation set has been merged into the train set.
- The train set has 7,101 lines from 326 writers.
- The test set has 1,861 lines from 128 writers.
- The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only.
-
- """
-
- def __init__(self) -> None:
- self.metadata = toml.load(METADATA_FILENAME)
-
- def load_or_generate_data(self) -> None:
- """Downloads IAM dataset if xml files does not exist."""
- if not self.xml_filenames:
- self._download_iam()
-
- @property
- def xml_filenames(self) -> List:
- """List of xml filenames."""
- return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml"))
-
- @property
- def form_filenames(self) -> List:
- """List of forms filenames."""
- return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg"))
-
- def _download_iam(self) -> None:
- curdir = os.getcwd()
- os.chdir(RAW_DATA_DIRNAME)
- _download_raw_dataset(self.metadata)
- _extract_raw_dataset(self.metadata)
- os.chdir(curdir)
-
- @property
- def form_filenames_by_id(self) -> Dict:
- """Creates a dictionary with filenames as keys and forms as values."""
- return {filename.stem: filename for filename in self.form_filenames}
-
- @cachedproperty
- def line_strings_by_id(self) -> Dict:
- """Return a dict from name of IAM form to a list of line texts in it."""
- return {
- filename.stem: _get_line_strings_from_xml_file(filename)
- for filename in self.xml_filenames
- }
-
- @cachedproperty
- def line_regions_by_id(self) -> Dict:
- """Return a dict from name of IAM form to a list of (x1, x2, y1, y2) coordinates of all lines in it."""
- return {
- filename.stem: _get_line_regions_from_xml_file(filename)
- for filename in self.xml_filenames
- }
-
- def __repr__(self) -> str:
- """Print info about dataset."""
- return "IAM Dataset\n" f"Number of forms: {len(self.xml_filenames)}\n"
-
-
-def _extract_raw_dataset(metadata: Dict) -> None:
- logger.info("Extracting IAM data.")
- with zipfile.ZipFile(metadata["filename"], "r") as zip_file:
- zip_file.extractall()
-
-
-def _get_line_strings_from_xml_file(filename: str) -> List[str]:
- """Get the text content of each line. Note that we replace &quot; with "."""
- xml_root_element = ET.parse(filename).getroot() # nosec
- xml_line_elements = xml_root_element.findall("handwritten-part/line")
- return [el.attrib["text"].replace("&quot;", '"') for el in xml_line_elements]
-
-
-def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]:
- """Get the line region dict for each line."""
- xml_root_element = ET.parse(filename).getroot() # nosec
- xml_line_elements = xml_root_element.findall("handwritten-part/line")
- return [_get_line_region_from_xml_element(el) for el in xml_line_elements]
-
-
-def _get_line_region_from_xml_element(xml_line: Any) -> Dict[str, int]:
- """Extracts coordinates for each line of text."""
- # TODO: fix input!
- word_elements = xml_line.findall("word/cmp")
- x1s = [int(el.attrib["x"]) for el in word_elements]
- y1s = [int(el.attrib["y"]) for el in word_elements]
- x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements]
- y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements]
- return {
- "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
- "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
- "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
- "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
- }
-
-
-def main() -> None:
- """Initializes the dataset and print info about the dataset."""
- dataset = IamDataset()
- dataset.load_or_generate_data()
- print(dataset)
-
-
-if __name__ == "__main__":
- main()
diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/src/text_recognizer/datasets/iam_lines_dataset.py
deleted file mode 100644
index 1cb84bd..0000000
--- a/src/text_recognizer/datasets/iam_lines_dataset.py
+++ /dev/null
@@ -1,110 +0,0 @@
-"""IamLinesDataset class."""
-from typing import Callable, Dict, List, Optional, Tuple, Union
-
-import h5py
-from loguru import logger
-import torch
-from torch import Tensor
-from torchvision.transforms import ToTensor
-
-from text_recognizer.datasets.dataset import Dataset
-from text_recognizer.datasets.util import (
- compute_sha256,
- DATA_DIRNAME,
- download_url,
- EmnistMapper,
-)
-
-
-PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines"
-PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "iam_lines.h5"
-PROCESSED_DATA_URL = (
- "https://s3-us-west-2.amazonaws.com/fsdl-public-assets/iam_lines.h5"
-)
-
-
-class IamLinesDataset(Dataset):
- """IAM lines datasets for handwritten text lines."""
-
- def __init__(
- self,
- train: bool = False,
- subsample_fraction: float = None,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
- init_token: Optional[str] = None,
- pad_token: Optional[str] = None,
- eos_token: Optional[str] = None,
- lower: bool = False,
- ) -> None:
- self.pad_token = "_" if pad_token is None else pad_token
-
- super().__init__(
- train=train,
- subsample_fraction=subsample_fraction,
- transform=transform,
- target_transform=target_transform,
- init_token=init_token,
- pad_token=pad_token,
- eos_token=eos_token,
- lower=lower,
- )
-
- @property
- def input_shape(self) -> Tuple:
- """Input shape of the data."""
- return self.data.shape[1:] if self.data is not None else None
-
- @property
- def output_shape(self) -> Tuple:
- """Output shape of the data."""
- return (
- self.targets.shape[1:] + (self.num_classes,)
- if self.targets is not None
- else None
- )
-
- def load_or_generate_data(self) -> None:
- """Load or generate dataset data."""
- if not PROCESSED_DATA_FILENAME.exists():
- PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
- logger.info("Downloading IAM lines...")
- download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME)
- with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
- self._data = f[f"x_{self.split}"][:]
- self._targets = f[f"y_{self.split}"][:]
- self._subsample()
-
- def __repr__(self) -> str:
- """Print info about the dataset."""
- return (
- "IAM Lines Dataset\n" # pylint: disable=no-member
- f"Number classes: {self.num_classes}\n"
- f"Mapping: {self.mapper.mapping}\n"
- f"Data: {self.data.shape}\n"
- f"Targets: {self.targets.shape}\n"
- )
-
- def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]:
- """Fetches data, target pair of the dataset for a given and index or indices.
-
- Args:
- index (Union[int, Tensor]): Either a list or int of indices/index.
-
- Returns:
- Tuple[Tensor, Tensor]: Data target pair.
-
- """
- if torch.is_tensor(index):
- index = index.tolist()
-
- data = self.data[index]
- targets = self.targets[index]
-
- if self.transform:
- data = self.transform(data)
-
- if self.target_transform:
- targets = self.target_transform(targets)
-
- return data, targets
diff --git a/src/text_recognizer/datasets/iam_paragraphs_dataset.py b/src/text_recognizer/datasets/iam_paragraphs_dataset.py
deleted file mode 100644
index 8ba5142..0000000
--- a/src/text_recognizer/datasets/iam_paragraphs_dataset.py
+++ /dev/null
@@ -1,291 +0,0 @@
-"""IamParagraphsDataset class and functions for data processing."""
-import random
-from typing import Callable, Dict, List, Optional, Tuple, Union
-
-import click
-import cv2
-import h5py
-from loguru import logger
-import numpy as np
-import torch
-from torch import Tensor
-from torchvision.transforms import ToTensor
-
-from text_recognizer import util
-from text_recognizer.datasets.dataset import Dataset
-from text_recognizer.datasets.iam_dataset import IamDataset
-from text_recognizer.datasets.util import (
- compute_sha256,
- DATA_DIRNAME,
- download_url,
- EmnistMapper,
-)
-
-INTERIM_DATA_DIRNAME = DATA_DIRNAME / "interim" / "iam_paragraphs"
-DEBUG_CROPS_DIRNAME = INTERIM_DATA_DIRNAME / "debug_crops"
-PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_paragraphs"
-CROPS_DIRNAME = PROCESSED_DATA_DIRNAME / "crops"
-GT_DIRNAME = PROCESSED_DATA_DIRNAME / "gt"
-
-PARAGRAPH_BUFFER = 50 # Pixels in the IAM form images to leave around the lines.
-TEST_FRACTION = 0.2
-SEED = 4711
-
-
-class IamParagraphsDataset(Dataset):
- """IAM Paragraphs dataset for paragraphs of handwritten text."""
-
- def __init__(
- self,
- train: bool = False,
- subsample_fraction: float = None,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
- ) -> None:
- super().__init__(
- train=train,
- subsample_fraction=subsample_fraction,
- transform=transform,
- target_transform=target_transform,
- )
- # Load Iam dataset.
- self.iam_dataset = IamDataset()
-
- self.num_classes = 3
- self._input_shape = (256, 256)
- self._output_shape = self._input_shape + (self.num_classes,)
- self._ids = None
-
- def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]:
- """Fetches data, target pair of the dataset for a given and index or indices.
-
- Args:
- index (Union[int, Tensor]): Either a list or int of indices/index.
-
- Returns:
- Tuple[Tensor, Tensor]: Data target pair.
-
- """
- if torch.is_tensor(index):
- index = index.tolist()
-
- data = self.data[index]
- targets = self.targets[index]
-
- seed = np.random.randint(SEED)
- random.seed(seed) # apply this seed to target tranfsorms
- torch.manual_seed(seed) # needed for torchvision 0.7
- if self.transform:
- data = self.transform(data)
-
- random.seed(seed) # apply this seed to target tranfsorms
- torch.manual_seed(seed) # needed for torchvision 0.7
- if self.target_transform:
- targets = self.target_transform(targets)
-
- return data, targets.long()
-
- @property
- def ids(self) -> Tensor:
- """Ids of the dataset."""
- return self._ids
-
- def get_data_and_target_from_id(self, id_: str) -> Tuple[Tensor, Tensor]:
- """Get data target pair from id."""
- ind = self.ids.index(id_)
- return self.data[ind], self.targets[ind]
-
- def load_or_generate_data(self) -> None:
- """Load or generate dataset data."""
- num_actual = len(list(CROPS_DIRNAME.glob("*.jpg")))
- num_targets = len(self.iam_dataset.line_regions_by_id)
-
- if num_actual < num_targets - 2:
- self._process_iam_paragraphs()
-
- self._data, self._targets, self._ids = _load_iam_paragraphs()
- self._get_random_split()
- self._subsample()
-
- def _get_random_split(self) -> None:
- np.random.seed(SEED)
- num_train = int((1 - TEST_FRACTION) * self.data.shape[0])
- indices = np.random.permutation(self.data.shape[0])
- train_indices, test_indices = indices[:num_train], indices[num_train:]
- if self.train:
- self._data = self.data[train_indices]
- self._targets = self.targets[train_indices]
- else:
- self._data = self.data[test_indices]
- self._targets = self.targets[test_indices]
-
- def _process_iam_paragraphs(self) -> None:
- """Crop the part with the text.
-
- For each page, crop out the part of it that correspond to the paragraph of text, and make sure all crops are
- self.input_shape. The ground truth data is the same size, with a one-hot vector at each pixel
- corresponding to labels 0=background, 1=odd-numbered line, 2=even-numbered line
- """
- crop_dims = self._decide_on_crop_dims()
- CROPS_DIRNAME.mkdir(parents=True, exist_ok=True)
- DEBUG_CROPS_DIRNAME.mkdir(parents=True, exist_ok=True)
- GT_DIRNAME.mkdir(parents=True, exist_ok=True)
- logger.info(
- f"Cropping paragraphs, generating ground truth, and saving debugging images to {DEBUG_CROPS_DIRNAME}"
- )
- for filename in self.iam_dataset.form_filenames:
- id_ = filename.stem
- line_region = self.iam_dataset.line_regions_by_id[id_]
- _crop_paragraph_image(filename, line_region, crop_dims, self.input_shape)
-
- def _decide_on_crop_dims(self) -> Tuple[int, int]:
- """Decide on the dimensions to crop out of the form image.
-
- Since image width is larger than a comfortable crop around the longest paragraph,
- we will make the crop a square form factor.
- And since the found dimensions 610x610 are pretty close to 512x512,
- we might as well resize crops and make it exactly that, which lets us
- do all kinds of power-of-2 pooling and upsampling should we choose to.
-
- Returns:
- Tuple[int, int]: A tuple of crop dimensions.
-
- Raises:
- RuntimeError: When max crop height is larger than max crop width.
-
- """
-
- sample_form_filename = self.iam_dataset.form_filenames[0]
- sample_image = util.read_image(sample_form_filename, grayscale=True)
- max_crop_width = sample_image.shape[1]
- max_crop_height = _get_max_paragraph_crop_height(
- self.iam_dataset.line_regions_by_id
- )
- if not max_crop_height <= max_crop_width:
- raise RuntimeError(
- f"Max crop height is larger then max crop width: {max_crop_height} >= {max_crop_width}"
- )
-
- crop_dims = (max_crop_width, max_crop_width)
- logger.info(
- f"Max crop width and height were found to be {max_crop_width}x{max_crop_height}."
- )
- logger.info(f"Setting them to {max_crop_width}x{max_crop_width}")
- return crop_dims
-
- def __repr__(self) -> str:
- """Return info about the dataset."""
- return (
- "IAM Paragraph Dataset\n" # pylint: disable=no-member
- f"Num classes: {self.num_classes}\n"
- f"Data: {self.data.shape}\n"
- f"Targets: {self.targets.shape}\n"
- )
-
-
-def _get_max_paragraph_crop_height(line_regions_by_id: Dict) -> int:
- heights = []
- for regions in line_regions_by_id.values():
- min_y1 = min(region["y1"] for region in regions) - PARAGRAPH_BUFFER
- max_y2 = max(region["y2"] for region in regions) + PARAGRAPH_BUFFER
- height = max_y2 - min_y1
- heights.append(height)
- return max(heights)
-
-
-def _crop_paragraph_image(
- filename: str, line_regions: Dict, crop_dims: Tuple[int, int], final_dims: Tuple
-) -> None:
- image = util.read_image(filename, grayscale=True)
-
- min_y1 = min(region["y1"] for region in line_regions) - PARAGRAPH_BUFFER
- max_y2 = max(region["y2"] for region in line_regions) + PARAGRAPH_BUFFER
- height = max_y2 - min_y1
- crop_height = crop_dims[0]
- buffer = (crop_height - height) // 2
-
- # Generate image crop.
- image_crop = 255 * np.ones(crop_dims, dtype=np.uint8)
- try:
- image_crop[buffer : buffer + height] = image[min_y1:max_y2]
- except Exception as e: # pylint: disable=broad-except
- logger.error(f"Rescued {filename}: {e}")
- return
-
- # Generate ground truth.
- gt_image = np.zeros_like(image_crop, dtype=np.uint8)
- for index, region in enumerate(line_regions):
- gt_image[
- (region["y1"] - min_y1 + buffer) : (region["y2"] - min_y1 + buffer),
- region["x1"] : region["x2"],
- ] = (index % 2 + 1)
-
- # Generate image for debugging.
- import matplotlib.pyplot as plt
-
- cmap = plt.get_cmap("Set1")
- image_crop_for_debug = np.dstack([image_crop, image_crop, image_crop])
- for index, region in enumerate(line_regions):
- color = [255 * _ for _ in cmap(index)[:-1]]
- cv2.rectangle(
- image_crop_for_debug,
- (region["x1"], region["y1"] - min_y1 + buffer),
- (region["x2"], region["y2"] - min_y1 + buffer),
- color,
- 3,
- )
- image_crop_for_debug = cv2.resize(
- image_crop_for_debug, final_dims, interpolation=cv2.INTER_AREA
- )
- util.write_image(image_crop_for_debug, DEBUG_CROPS_DIRNAME / f"{filename.stem}.jpg")
-
- image_crop = cv2.resize(image_crop, final_dims, interpolation=cv2.INTER_AREA)
- util.write_image(image_crop, CROPS_DIRNAME / f"{filename.stem}.jpg")
-
- gt_image = cv2.resize(gt_image, final_dims, interpolation=cv2.INTER_NEAREST)
- util.write_image(gt_image, GT_DIRNAME / f"{filename.stem}.png")
-
-
-def _load_iam_paragraphs() -> None:
- logger.info("Loading IAM paragraph crops and ground truth from image files...")
- images = []
- gt_images = []
- ids = []
- for filename in CROPS_DIRNAME.glob("*.jpg"):
- id_ = filename.stem
- image = util.read_image(filename, grayscale=True)
- image = 1.0 - image / 255
-
- gt_filename = GT_DIRNAME / f"{id_}.png"
- gt_image = util.read_image(gt_filename, grayscale=True)
-
- images.append(image)
- gt_images.append(gt_image)
- ids.append(id_)
- images = np.array(images).astype(np.float32)
- gt_images = np.array(gt_images).astype(np.uint8)
- ids = np.array(ids)
- return images, gt_images, ids
-
-
-@click.command()
-@click.option(
- "--subsample_fraction",
- type=float,
- default=None,
- help="The subsampling factor of the dataset.",
-)
-def main(subsample_fraction: float) -> None:
- """Load dataset and print info."""
- logger.info("Creating train set...")
- dataset = IamParagraphsDataset(train=True, subsample_fraction=subsample_fraction)
- dataset.load_or_generate_data()
- print(dataset)
- logger.info("Creating test set...")
- dataset = IamParagraphsDataset(subsample_fraction=subsample_fraction)
- dataset.load_or_generate_data()
- print(dataset)
-
-
-if __name__ == "__main__":
- main()
diff --git a/src/text_recognizer/datasets/iam_preprocessor.py b/src/text_recognizer/datasets/iam_preprocessor.py
deleted file mode 100644
index a93eb00..0000000
--- a/src/text_recognizer/datasets/iam_preprocessor.py
+++ /dev/null
@@ -1,196 +0,0 @@
-"""Preprocessor for extracting word letters from the IAM dataset.
-
-The code is mostly stolen from:
-
- https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py
-
-"""
-
-import collections
-import itertools
-from pathlib import Path
-import re
-from typing import List, Optional, Union
-
-import click
-from loguru import logger
-import torch
-
-
-def load_metadata(
- data_dir: Path, wordsep: str, use_words: bool = False
-) -> collections.defaultdict:
- """Loads IAM metadata and returns it as a dictionary."""
- forms = collections.defaultdict(list)
- filename = "words.txt" if use_words else "lines.txt"
-
- with open(data_dir / "ascii" / filename, "r") as f:
- lines = (line.strip().split() for line in f if line[0] != "#")
- for line in lines:
- # Skip word segmentation errors.
- if use_words and line[1] == "err":
- continue
- text = " ".join(line[8:])
-
- # Remove garbage tokens:
- text = text.replace("#", "")
-
- # Swap word sep form | to wordsep
- text = re.sub(r"\|+|\s", wordsep, text).strip(wordsep)
- form_key = "-".join(line[0].split("-")[:2])
- line_key = "-".join(line[0].split("-")[:3])
- box_idx = 4 - use_words
- box = tuple(int(val) for val in line[box_idx : box_idx + 4])
- forms[form_key].append({"key": line_key, "box": box, "text": text})
- return forms
-
-
-class Preprocessor:
- """A preprocessor for the IAM dataset."""
-
- # TODO: add lower case only to when generating...
-
- def __init__(
- self,
- data_dir: Union[str, Path],
- num_features: int,
- tokens_path: Optional[Union[str, Path]] = None,
- lexicon_path: Optional[Union[str, Path]] = None,
- use_words: bool = False,
- prepend_wordsep: bool = False,
- ) -> None:
- self.wordsep = "▁"
- self._use_word = use_words
- self._prepend_wordsep = prepend_wordsep
-
- self.data_dir = Path(data_dir)
-
- self.forms = load_metadata(self.data_dir, self.wordsep, use_words=use_words)
-
- # Load the set of graphemes:
- graphemes = set()
- for _, form in self.forms.items():
- for line in form:
- graphemes.update(line["text"].lower())
- self.graphemes = sorted(graphemes)
-
- # Build the token-to-index and index-to-token maps.
- if tokens_path is not None:
- with open(tokens_path, "r") as f:
- self.tokens = [line.strip() for line in f]
- else:
- self.tokens = self.graphemes
-
- if lexicon_path is not None:
- with open(lexicon_path, "r") as f:
- lexicon = (line.strip().split() for line in f)
- lexicon = {line[0]: line[1:] for line in lexicon}
- self.lexicon = lexicon
- else:
- self.lexicon = None
-
- self.graphemes_to_index = {t: i for i, t in enumerate(self.graphemes)}
- self.tokens_to_index = {t: i for i, t in enumerate(self.tokens)}
- self.num_features = num_features
- self.text = []
-
- @property
- def num_tokens(self) -> int:
- """Returns the number or tokens."""
- return len(self.tokens)
-
- @property
- def use_words(self) -> bool:
- """If words are used."""
- return self._use_word
-
- def extract_train_text(self) -> None:
- """Extracts training text."""
- keys = []
- with open(self.data_dir / "task" / "trainset.txt") as f:
- keys.extend((line.strip() for line in f))
-
- for _, examples in self.forms.items():
- for example in examples:
- if example["key"] not in keys:
- continue
- self.text.append(example["text"].lower())
-
- def to_index(self, line: str) -> torch.LongTensor:
- """Converts text to a tensor of indices."""
- token_to_index = self.graphemes_to_index
- if self.lexicon is not None:
- if len(line) > 0:
- # If the word is not found in the lexicon, fall back to letters.
- line = [
- t
- for w in line.split(self.wordsep)
- for t in self.lexicon.get(w, self.wordsep + w)
- ]
- token_to_index = self.tokens_to_index
- if self._prepend_wordsep:
- line = itertools.chain([self.wordsep], line)
- return torch.LongTensor([token_to_index[t] for t in line])
-
- def to_text(self, indices: List[int]) -> str:
- """Converts indices to text."""
- # Roughly the inverse of `to_index`
- encoding = self.graphemes
- if self.lexicon is not None:
- encoding = self.tokens
- return self._post_process(encoding[i] for i in indices)
-
- def tokens_to_text(self, indices: List[int]) -> str:
- """Converts tokens to text."""
- return self._post_process(self.tokens[i] for i in indices)
-
- def _post_process(self, indices: List[int]) -> str:
- """A list join."""
- return "".join(indices).strip(self.wordsep)
-
-
-@click.command()
-@click.option("--data_dir", type=str, default=None, help="Path to iam dataset")
-@click.option(
- "--use_words", is_flag=True, help="Load word segmented dataset instead of lines"
-)
-@click.option(
- "--save_text", type=str, default=None, help="Path to save parsed train text"
-)
-@click.option("--save_tokens", type=str, default=None, help="Path to save tokens")
-def cli(
- data_dir: Optional[str],
- use_words: bool,
- save_text: Optional[str],
- save_tokens: Optional[str],
-) -> None:
- """CLI for extracting text data from the iam dataset."""
- if data_dir is None:
- data_dir = (
- Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb"
- )
- logger.debug(f"Using data dir: {data_dir}")
- if not data_dir.exists():
- raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
- else:
- data_dir = Path(data_dir)
-
- preprocessor = Preprocessor(data_dir, 64, use_words=use_words)
- preprocessor.extract_train_text()
-
- processed_dir = data_dir.parents[2] / "processed" / "iam_lines"
- logger.debug(f"Saving processed files at: {processed_dir}")
-
- if save_text is not None:
- logger.info("Saving training text")
- with open(processed_dir / save_text, "w") as f:
- f.write("\n".join(t for t in preprocessor.text))
-
- if save_tokens is not None:
- logger.info("Saving tokens")
- with open(processed_dir / save_tokens, "w") as f:
- f.write("\n".join(preprocessor.tokens))
-
-
-if __name__ == "__main__":
- cli()
diff --git a/src/text_recognizer/datasets/sentence_generator.py b/src/text_recognizer/datasets/sentence_generator.py
deleted file mode 100644
index dd76652..0000000
--- a/src/text_recognizer/datasets/sentence_generator.py
+++ /dev/null
@@ -1,81 +0,0 @@
-"""Downloading the Brown corpus with NLTK for sentence generating."""
-
-import itertools
-import re
-import string
-from typing import Optional
-
-import nltk
-from nltk.corpus.reader.util import ConcatenatedCorpusView
-import numpy as np
-
-from text_recognizer.datasets.util import DATA_DIRNAME
-
-NLTK_DATA_DIRNAME = DATA_DIRNAME / "raw" / "nltk"
-
-
-class SentenceGenerator:
- """Generates text sentences using the Brown corpus."""
-
- def __init__(self, max_length: Optional[int] = None) -> None:
- """Loads the corpus and sets word start indices."""
- self.corpus = brown_corpus()
- self.word_start_indices = [0] + [
- _.start(0) + 1 for _ in re.finditer(" ", self.corpus)
- ]
- self.max_length = max_length
-
- def generate(self, max_length: Optional[int] = None) -> str:
- """Generates a word or sentences from the Brown corpus.
-
- Sample a string from the Brown corpus of length at least one word and at most max_length, padding to
- max_length with the '_' characters if sentence is shorter.
-
- Args:
- max_length (Optional[int]): The maximum number of characters in the sentence. Defaults to None.
-
- Returns:
- str: A sentence from the Brown corpus.
-
- Raises:
- ValueError: If max_length was not specified at initialization and not given as an argument.
-
- """
- if max_length is None:
- max_length = self.max_length
- if max_length is None:
- raise ValueError(
- "Must provide max_length to this method or when making this object."
- )
-
- index = np.random.randint(0, len(self.word_start_indices) - 1)
- start_index = self.word_start_indices[index]
- end_index_candidates = []
- for index in range(index + 1, len(self.word_start_indices)):
- if self.word_start_indices[index] - start_index > max_length:
- break
- end_index_candidates.append(self.word_start_indices[index])
- end_index = np.random.choice(end_index_candidates)
- sampled_text = self.corpus[start_index:end_index].strip()
- padding = "_" * (max_length - len(sampled_text))
- return sampled_text + padding
-
-
-def brown_corpus() -> str:
- """Returns a single string with the Brown corpus with all punctuations stripped."""
- sentences = load_nltk_brown_corpus()
- corpus = " ".join(itertools.chain.from_iterable(sentences))
- corpus = corpus.translate({ord(c): None for c in string.punctuation})
- corpus = re.sub(" +", " ", corpus)
- return corpus
-
-
-def load_nltk_brown_corpus() -> ConcatenatedCorpusView:
- """Load the Brown corpus using the NLTK library."""
- nltk.data.path.append(NLTK_DATA_DIRNAME)
- try:
- nltk.corpus.brown.sents()
- except LookupError:
- NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
- nltk.download("brown", download_dir=NLTK_DATA_DIRNAME)
- return nltk.corpus.brown.sents()
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
deleted file mode 100644
index b6a48f5..0000000
--- a/src/text_recognizer/datasets/transforms.py
+++ /dev/null
@@ -1,266 +0,0 @@
-"""Transforms for PyTorch datasets."""
-from abc import abstractmethod
-from pathlib import Path
-import random
-from typing import Any, Optional, Union
-
-from loguru import logger
-import numpy as np
-from PIL import Image
-import torch
-from torch import Tensor
-import torch.nn.functional as F
-from torchvision import transforms
-from torchvision.transforms import (
- ColorJitter,
- Compose,
- Normalize,
- RandomAffine,
- RandomHorizontalFlip,
- RandomRotation,
- ToPILImage,
- ToTensor,
-)
-
-from text_recognizer.datasets.iam_preprocessor import Preprocessor
-from text_recognizer.datasets.util import EmnistMapper
-
-
-class RandomResizeCrop:
- """Image transform with random resize and crop applied.
-
- Stolen from
-
- https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py
-
- """
-
- def __init__(self, jitter: int = 10, ratio: float = 0.5) -> None:
- self.jitter = jitter
- self.ratio = ratio
-
- def __call__(self, img: np.ndarray) -> np.ndarray:
- """Applies random crop and rotation to an image."""
- w, h = img.size
-
- # pad with white:
- img = transforms.functional.pad(img, self.jitter, fill=255)
-
- # crop at random (x, y):
- x = self.jitter + random.randint(-self.jitter, self.jitter)
- y = self.jitter + random.randint(-self.jitter, self.jitter)
-
- # randomize aspect ratio:
- size_w = w * random.uniform(1 - self.ratio, 1 + self.ratio)
- size = (h, int(size_w))
- img = transforms.functional.resized_crop(img, y, x, h, w, size)
- return img
-
-
-class Transpose:
- """Transposes the EMNIST image to the correct orientation."""
-
- def __call__(self, image: Image) -> np.ndarray:
- """Swaps axis."""
- return np.array(image).swapaxes(0, 1)
-
-
-class Resize:
- """Resizes a tensor to a specified width."""
-
- def __init__(self, width: int = 952) -> None:
- # The default is 952 because of the IAM dataset.
- self.width = width
-
- def __call__(self, image: Tensor) -> Tensor:
- """Resize tensor in the last dimension."""
- return F.interpolate(image, size=self.width, mode="nearest")
-
-
-class AddTokens:
- """Adds start of sequence and end of sequence tokens to target tensor."""
-
- def __init__(self, pad_token: str, eos_token: str, init_token: str = None) -> None:
- self.init_token = init_token
- self.pad_token = pad_token
- self.eos_token = eos_token
- if self.init_token is not None:
- self.emnist_mapper = EmnistMapper(
- init_token=self.init_token,
- pad_token=self.pad_token,
- eos_token=self.eos_token,
- )
- else:
- self.emnist_mapper = EmnistMapper(
- pad_token=self.pad_token, eos_token=self.eos_token,
- )
- self.pad_value = self.emnist_mapper(self.pad_token)
- self.eos_value = self.emnist_mapper(self.eos_token)
-
- def __call__(self, target: Tensor) -> Tensor:
- """Adds a sos token to the begining and a eos token to the end of a target sequence."""
- dtype, device = target.dtype, target.device
-
- # Find the where padding starts.
- pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item()
-
- target[pad_index] = self.eos_value
-
- if self.init_token is not None:
- self.sos_value = self.emnist_mapper(self.init_token)
- sos = torch.tensor([self.sos_value], dtype=dtype, device=device)
- target = torch.cat([sos, target], dim=0)
-
- return target
-
-
-class ApplyContrast:
- """Sets everything below a threshold to zero, i.e. increase contrast."""
-
- def __init__(self, low: float = 0.0, high: float = 0.25) -> None:
- self.low = low
- self.high = high
-
- def __call__(self, x: Tensor) -> Tensor:
- """Apply mask binary mask to input tensor."""
- mask = x > np.random.RandomState().uniform(low=self.low, high=self.high)
- return x * mask
-
-
-class Unsqueeze:
- """Add a dimension to the tensor."""
-
- def __call__(self, x: Tensor) -> Tensor:
- """Adds dim."""
- return x.unsqueeze(0)
-
-
-class Squeeze:
- """Removes the first dimension of a tensor."""
-
- def __call__(self, x: Tensor) -> Tensor:
- """Removes first dim."""
- return x.squeeze(0)
-
-
-class ToLower:
- """Converts target to lower case."""
-
- def __call__(self, target: Tensor) -> Tensor:
- """Corrects index value in target tensor."""
- device = target.device
- return torch.stack([x - 26 if x > 35 else x for x in target]).to(device)
-
-
-class ToCharcters:
- """Converts integers to characters."""
-
- def __init__(
- self, pad_token: str, eos_token: str, init_token: str = None, lower: bool = True
- ) -> None:
- self.init_token = init_token
- self.pad_token = pad_token
- self.eos_token = eos_token
- if self.init_token is not None:
- self.emnist_mapper = EmnistMapper(
- init_token=self.init_token,
- pad_token=self.pad_token,
- eos_token=self.eos_token,
- lower=lower,
- )
- else:
- self.emnist_mapper = EmnistMapper(
- pad_token=self.pad_token, eos_token=self.eos_token, lower=lower
- )
-
- def __call__(self, y: Tensor) -> str:
- """Converts a Tensor to a str."""
- return (
- "".join([self.emnist_mapper(int(i)) for i in y])
- .strip("_")
- .replace(" ", "▁")
- )
-
-
-class WordPieces:
- """Abstract transform for word pieces."""
-
- def __init__(
- self,
- num_features: int,
- data_dir: Optional[Union[str, Path]] = None,
- tokens: Optional[Union[str, Path]] = None,
- lexicon: Optional[Union[str, Path]] = None,
- use_words: bool = False,
- prepend_wordsep: bool = False,
- ) -> None:
- if data_dir is None:
- data_dir = (
- Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb"
- )
- logger.debug(f"Using data dir: {data_dir}")
- if not data_dir.exists():
- raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
- else:
- data_dir = Path(data_dir)
- processed_path = (
- Path(__file__).resolve().parents[3] / "data" / "processed" / "iam_lines"
- )
- tokens_path = processed_path / tokens
- lexicon_path = processed_path / lexicon
-
- self.preprocessor = Preprocessor(
- data_dir,
- num_features,
- tokens_path,
- lexicon_path,
- use_words,
- prepend_wordsep,
- )
-
- @abstractmethod
- def __call__(self, *args, **kwargs) -> Any:
- """Transforms input."""
- ...
-
-
-class ToWordPieces(WordPieces):
- """Transforms str to word pieces."""
-
- def __init__(
- self,
- num_features: int,
- data_dir: Optional[Union[str, Path]] = None,
- tokens: Optional[Union[str, Path]] = None,
- lexicon: Optional[Union[str, Path]] = None,
- use_words: bool = False,
- prepend_wordsep: bool = False,
- ) -> None:
- super().__init__(
- num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep
- )
-
- def __call__(self, line: str) -> Tensor:
- """Transforms str to word pieces."""
- return self.preprocessor.to_index(line)
-
-
-class ToText(WordPieces):
- """Takes word pieces and converts them to text."""
-
- def __init__(
- self,
- num_features: int,
- data_dir: Optional[Union[str, Path]] = None,
- tokens: Optional[Union[str, Path]] = None,
- lexicon: Optional[Union[str, Path]] = None,
- use_words: bool = False,
- prepend_wordsep: bool = False,
- ) -> None:
- super().__init__(
- num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep
- )
-
- def __call__(self, x: Tensor) -> str:
- """Converts tensor to text."""
- return self.preprocessor.to_text(x.tolist())
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
deleted file mode 100644
index da87756..0000000
--- a/src/text_recognizer/datasets/util.py
+++ /dev/null
@@ -1,209 +0,0 @@
-"""Util functions for datasets."""
-import hashlib
-import json
-import os
-from pathlib import Path
-import string
-from typing import Dict, List, Optional, Union
-from urllib.request import urlretrieve
-
-from loguru import logger
-import numpy as np
-import torch
-from torch import Tensor
-from torchvision.datasets import EMNIST
-from tqdm import tqdm
-
-DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data"
-ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json"
-
-
-def save_emnist_essentials(emnsit_dataset: EMNIST = EMNIST) -> None:
- """Extract and saves EMNIST essentials."""
- labels = emnsit_dataset.classes
- labels.sort()
- mapping = [(i, str(label)) for i, label in enumerate(labels)]
- essentials = {
- "mapping": mapping,
- "input_shape": tuple(np.array(emnsit_dataset[0][0]).shape[:]),
- }
- logger.info("Saving emnist essentials...")
- with open(ESSENTIALS_FILENAME, "w") as f:
- json.dump(essentials, f)
-
-
-def download_emnist() -> None:
- """Download the EMNIST dataset via the PyTorch class."""
- logger.info(f"Data directory is: {DATA_DIRNAME}")
- dataset = EMNIST(root=DATA_DIRNAME, split="byclass", download=True)
- save_emnist_essentials(dataset)
-
-
-class EmnistMapper:
- """Mapper between network output to Emnist character."""
-
- def __init__(
- self,
- pad_token: str,
- init_token: Optional[str] = None,
- eos_token: Optional[str] = None,
- lower: bool = False,
- ) -> None:
- """Loads the emnist essentials file with the mapping and input shape."""
- self.init_token = init_token
- self.pad_token = pad_token
- self.eos_token = eos_token
- self.lower = lower
-
- self.essentials = self._load_emnist_essentials()
- # Load dataset information.
- self._mapping = dict(self.essentials["mapping"])
- self._augment_emnist_mapping()
- self._inverse_mapping = {v: k for k, v in self.mapping.items()}
- self._num_classes = len(self.mapping)
- self._input_shape = self.essentials["input_shape"]
-
- def __call__(self, token: Union[str, int, np.uint8, Tensor]) -> Union[str, int]:
- """Maps the token to emnist character or character index.
-
- If the token is an integer (index), the method will return the Emnist character corresponding to that index.
- If the token is a str (Emnist character), the method will return the corresponding index for that character.
-
- Args:
- token (Union[str, int, np.uint8, Tensor]): Either a string or index (integer).
-
- Returns:
- Union[str, int]: The mapping result.
-
- Raises:
- KeyError: If the index or string does not exist in the mapping.
-
- """
- if (
- (isinstance(token, np.uint8) or isinstance(token, int))
- or torch.is_tensor(token)
- and int(token) in self.mapping
- ):
- return self.mapping[int(token)]
- elif isinstance(token, str) and token in self._inverse_mapping:
- return self._inverse_mapping[token]
- else:
- raise KeyError(f"Token {token} does not exist in the mappings.")
-
- @property
- def mapping(self) -> Dict:
- """Returns the mapping between index and character."""
- return self._mapping
-
- @property
- def inverse_mapping(self) -> Dict:
- """Returns the mapping between character and index."""
- return self._inverse_mapping
-
- @property
- def num_classes(self) -> int:
- """Returns the number of classes in the dataset."""
- return self._num_classes
-
- @property
- def input_shape(self) -> List[int]:
- """Returns the input shape of the Emnist characters."""
- return self._input_shape
-
- def _load_emnist_essentials(self) -> Dict:
- """Load the EMNIST mapping."""
- with open(str(ESSENTIALS_FILENAME)) as f:
- essentials = json.load(f)
- return essentials
-
- def _augment_emnist_mapping(self) -> None:
- """Augment the mapping with extra symbols."""
- # Extra symbols in IAM dataset
- if self.lower:
- self._mapping = {
- k: str(v)
- for k, v in enumerate(list(range(10)) + list(string.ascii_lowercase))
- }
-
- extra_symbols = [
- " ",
- "!",
- '"',
- "#",
- "&",
- "'",
- "(",
- ")",
- "*",
- "+",
- ",",
- "-",
- ".",
- "/",
- ":",
- ";",
- "?",
- ]
-
- # padding symbol, and acts as blank symbol as well.
- extra_symbols.append(self.pad_token)
-
- if self.init_token is not None:
- extra_symbols.append(self.init_token)
-
- if self.eos_token is not None:
- extra_symbols.append(self.eos_token)
-
- max_key = max(self.mapping.keys())
- extra_mapping = {}
- for i, symbol in enumerate(extra_symbols):
- extra_mapping[max_key + 1 + i] = symbol
-
- self._mapping = {**self.mapping, **extra_mapping}
-
-
-def compute_sha256(filename: Union[Path, str]) -> str:
- """Returns the SHA256 checksum of a file."""
- with open(filename, "rb") as f:
- return hashlib.sha256(f.read()).hexdigest()
-
-
-class TqdmUpTo(tqdm):
- """TQDM progress bar when downloading files.
-
- From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
-
- """
-
- def update_to(
- self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None
- ) -> None:
- """Updates the progress bar.
-
- Args:
- blocks (int): Number of blocks transferred so far. Defaults to 1.
- block_size (int): Size of each block, in tqdm units. Defaults to 1.
- total_size (Optional[int]): Total size in tqdm units. Defaults to None.
- """
- if total_size is not None:
- self.total = total_size # pylint: disable=attribute-defined-outside-init
- self.update(blocks * block_size - self.n)
-
-
-def download_url(url: str, filename: str) -> None:
- """Downloads a file from url to filename, with a progress bar."""
- with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
- urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec
-
-
-def _download_raw_dataset(metadata: Dict) -> None:
- if os.path.exists(metadata["filename"]):
- return
- logger.info(f"Downloading raw dataset from {metadata['url']}...")
- download_url(metadata["url"], metadata["filename"])
- logger.info("Computing SHA-256...")
- sha256 = compute_sha256(metadata["filename"])
- if sha256 != metadata["sha256"]:
- raise ValueError(
- "Downloaded data file SHA-256 does not match that listed in metadata document."
- )