path: root/text_recognizer/data
author: Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-03-24 22:15:54 +0100
committer: Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-03-24 22:15:54 +0100
commit: 8248f173132dfb7e47ec62b08e9235990c8626e3 (patch)
tree: 2f3ff85602cbc08b7168bf4f0d3924d32a689852 /text_recognizer/data
parent: 74c907a17379688967dc4b3f41a44ba83034f5e0 (diff)
renamed datasets to data, added iam refactor
Diffstat (limited to 'text_recognizer/data')
-rw-r--r--  text_recognizer/data/__init__.py                  1
-rw-r--r--  text_recognizer/data/base_data_module.py         89
-rw-r--r--  text_recognizer/data/base_dataset.py             73
-rw-r--r--  text_recognizer/data/download_utils.py           73
-rw-r--r--  text_recognizer/data/emnist.py                  210
-rw-r--r--  text_recognizer/data/emnist_essentials.json       1
-rw-r--r--  text_recognizer/data/emnist_lines.py            280
-rw-r--r--  text_recognizer/data/iam.py                     120
-rw-r--r--  text_recognizer/data/iam_dataset.py             133
-rw-r--r--  text_recognizer/data/iam_lines_dataset.py       110
-rw-r--r--  text_recognizer/data/iam_paragraphs_dataset.py  291
-rw-r--r--  text_recognizer/data/iam_preprocessor.py        196
-rw-r--r--  text_recognizer/data/sentence_generator.py       85
-rw-r--r--  text_recognizer/data/transforms.py              266
-rw-r--r--  text_recognizer/data/util.py                    209
15 files changed, 2137 insertions, 0 deletions
diff --git a/text_recognizer/data/__init__.py b/text_recognizer/data/__init__.py
new file mode 100644
index 0000000..2727b20
--- /dev/null
+++ b/text_recognizer/data/__init__.py
@@ -0,0 +1 @@
+"""Dataset modules."""
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
new file mode 100644
index 0000000..f5e7300
--- /dev/null
+++ b/text_recognizer/data/base_data_module.py
@@ -0,0 +1,89 @@
+"""Base lightning DataModule class."""
+from pathlib import Path
+from typing import Dict, Optional
+
+import pytorch_lightning as pl
+from torch.utils.data import DataLoader
+
+
+def load_and_print_info(data_module_class: type) -> None:
+ """Load EMNISTLines and prints info."""
+ dataset = data_module_class()
+ dataset.prepare_data()
+ dataset.setup()
+ print(dataset)
+
+
+class BaseDataModule(pl.LightningDataModule):
+ """Base PyTorch Lightning DataModule."""
+
+ def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None:
+ super().__init__()
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+
+ # Placeholders for subclasses.
+ self.dims = None
+ self.output_dims = None
+ self.mapping = None
+
+ @classmethod
+ def data_dirname(cls) -> Path:
+ """Return the path to the base data directory."""
+ return Path(__file__).resolve().parents[2] / "data"
+
+ def config(self) -> Dict:
+ """Return important settings of the dataset."""
+ return {
+ "input_dim": self.dims,
+ "output_dims": self.output_dims,
+ "mapping": self.mapping,
+ }
+
+ def prepare_data(self) -> None:
+ """Prepare data for training."""
+ pass
+
+ def setup(self, stage: Optional[str] = None) -> None:
+ """Split into train, val, test, and set dims.
+
+ Should assign `torch Dataset` objects to self.data_train, self.data_val, and
+ optionally self.data_test.
+
+ Args:
+ stage (Optional[str]): Stage of setup, e.g. "fit" or "test". Defaults to None.
+
+ """
+ self.data_train = None
+ self.data_val = None
+ self.data_test = None
+
+ def train_dataloader(self) -> DataLoader:
+ """Retun DataLoader for train data."""
+ return DataLoader(
+ self.data_train,
+ shuffle=True,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ )
+
+ def val_dataloader(self) -> DataLoader:
+ """Return DataLoader for val data."""
+ return DataLoader(
+ self.data_val,
+ shuffle=False,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ )
+
+ def test_dataloader(self) -> DataLoader:
+ """Return DataLoader for val data."""
+ return DataLoader(
+ self.data_test,
+ shuffle=False,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ )
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py
new file mode 100644
index 0000000..a9e9c24
--- /dev/null
+++ b/text_recognizer/data/base_dataset.py
@@ -0,0 +1,73 @@
+"""Base PyTorch Dataset class."""
+from typing import Any, Callable, Dict, Sequence, Tuple, Union
+
+import torch
+from torch import Tensor
+from torch.utils.data import Dataset
+
+
+class BaseDataset(Dataset):
+ """
+ Base Dataset class that processes data and targets through optional transforms.
+
+ Args:
+ data (Union[Sequence, Tensor]): Torch tensors, numpy arrays, or PIL images.
+ targets (Union[Sequence, Tensor]): Torch tensors or numpy arrays.
+ transform (Callable): Function that takes a datum and applies transforms.
+ target_transform (Callable): Function that takes a target and applies
+ target transforms.
+ """
+
+ def __init__(
+ self,
+ data: Union[Sequence, Tensor],
+ targets: Union[Sequence, Tensor],
+ transform: Callable = None,
+ target_transform: Callable = None,
+ ) -> None:
+ if len(data) != len(targets):
+ raise ValueError("Data and targets must be of equal length.")
+ self.data = data
+ self.targets = targets
+ self.transform = transform
+ self.target_transform = target_transform
+
+ def __len__(self) -> int:
+ """Return the length of the dataset."""
+ return len(self.data)
+
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
+ """Return a datum and its target, after processing by transforms.
+
+ Args:
+ index (int): Index of a datum in the dataset.
+
+ Returns:
+ Tuple[Any, Any]: Datum and target pair.
+
+ """
+ datum, target = self.data[index], self.targets[index]
+
+ if self.transform is not None:
+ datum = self.transform(datum)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return datum, target
+
+
+def convert_strings_to_labels(
+ strings: Sequence[str], mapping: Dict[str, int], length: int
+) -> Tensor:
+ """
+ Convert a sequence of N strings to an (N, length) tensor, with each string wrapped with <s> and </s> tokens,
+ and padded with the <p> token.
+ """
+ labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["<p>"]
+ for i, string in enumerate(strings):
+ tokens = list(string)
+ tokens = ["<s>", *tokens, "</s>"]
+ for j, token in enumerate(tokens):
+ labels[i, j] = mapping[token]
+ return labels
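+
+
+# Worked example with a toy mapping (illustration only, not part of the module):
+#   mapping = {"<s>": 1, "</s>": 2, "<p>": 3, "a": 4, "b": 5}
+#   convert_strings_to_labels(["ab", "a"], mapping, length=6)
+#   tensor([[1, 4, 5, 2, 3, 3],
+#           [1, 4, 2, 3, 3, 3]])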
diff --git a/text_recognizer/data/download_utils.py b/text_recognizer/data/download_utils.py
new file mode 100644
index 0000000..e3dc68c
--- /dev/null
+++ b/text_recognizer/data/download_utils.py
@@ -0,0 +1,73 @@
+"""Util functions for downloading datasets."""
+import hashlib
+from pathlib import Path
+from typing import Dict, List, Optional
+from urllib.request import urlretrieve
+
+from loguru import logger
+from tqdm import tqdm
+
+
+def _compute_sha256(filename: Path) -> str:
+ """Returns the SHA256 checksum of a file."""
+ with filename.open(mode="rb") as f:
+ return hashlib.sha256(f.read()).hexdigest()
+
+
+class TqdmUpTo(tqdm):
+ """TQDM progress bar when downloading files.
+
+ From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
+
+ """
+
+ def update_to(
+ self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None
+ ) -> None:
+ """Updates the progress bar.
+
+ Args:
+ blocks (int): Number of blocks transferred so far. Defaults to 1.
+ block_size (int): Size of each block, in tqdm units. Defaults to 1.
+ total_size (Optional[int]): Total size in tqdm units. Defaults to None.
+ """
+ if total_size is not None:
+ self.total = total_size # pylint: disable=attribute-defined-outside-init
+ self.update(blocks * block_size - self.n)
+
+
+def _download_url(url: str, filename: str) -> None:
+ """Downloads a file from url to filename, with a progress bar."""
+ with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
+ urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec
+
+
+def download_dataset(metadata: Dict, dl_dir: Path) -> Optional[Path]:
+ """Downloads dataset using a metadata file.
+
+ Args:
+ metadata (Dict): A metadata file of the dataset.
+ dl_dir (Path): Download directory for the dataset.
+
+ Returns:
+ Optional[Path]: Returns filename if dataset is downloaded, None if it already
+ exists.
+
+ Raises:
+ ValueError: If the SHA-256 value is not the same between the dataset and
+ the metadata file.
+
+ """
+ dl_dir.mkdir(parents=True, exist_ok=True)
+ filename = dl_dir / metadata["filename"]
+ if filename.exists():
+ return
+ logger.info(f"Downloading raw dataset from {metadata['url']} to {filename}...")
+ _download_url(metadata["url"], filename)
+ logger.info("Computing the SHA-256...")
+ sha256 = _compute_sha256(filename)
+ if sha256 != metadata["sha256"]:
+ raise ValueError(
+ "Downloaded data file SHA-256 does not match that listed in metadata document."
+ )
+ return filename
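+
+
+# Usage sketch (the metadata values are illustrative placeholders):
+#   metadata = {"filename": "emnist.zip", "url": "https://...", "sha256": "..."}
+#   filename = download_dataset(metadata, Path("data/downloaded/emnist"))
+#   # filename is None if the archive was already present, otherwise the
+#   # downloaded and checksum-verified path.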
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
new file mode 100644
index 0000000..7f67893
--- /dev/null
+++ b/text_recognizer/data/emnist.py
@@ -0,0 +1,210 @@
+"""EMNIST dataset: downloads it from FSDL aws url if not present."""
+from pathlib import Path
+from typing import Optional, Sequence, Tuple
+import json
+import os
+import shutil
+import zipfile
+
+import h5py
+import numpy as np
+from loguru import logger
+import toml
+import torch
+from torch.utils.data import random_split
+from torchvision import transforms
+
+from text_recognizer.data.base_dataset import BaseDataset
+from text_recognizer.data.base_data_module import (
+ BaseDataModule,
+ load_and_print_info,
+)
+from text_recognizer.data.download_utils import download_dataset
+
+
+SEED = 4711
+NUM_SPECIAL_TOKENS = 4
+SAMPLE_TO_BALANCE = True
+
+RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "emnist"
+METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
+DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "emnist"
+PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist"
+PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5"
+ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json"
+
+
+class EMNIST(BaseDataModule):
+ """
+ "The EMNIST dataset is a set of handwritten character digits derived from the NIST Special Database 19
+ and converted to a 28x28 pixel image format and dataset structure that directly matches the MNIST dataset."
+ From https://www.nist.gov/itl/iad/image-group/emnist-dataset
+
+ The data split we will use is
+ EMNIST ByClass: 814,255 characters. 62 unbalanced classes.
+ """
+
+ def __init__(
+ self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8
+ ) -> None:
+ super().__init__(batch_size, num_workers)
+ if not ESSENTIALS_FILENAME.exists():
+ _download_and_process_emnist()
+ with ESSENTIALS_FILENAME.open() as f:
+ essentials = json.load(f)
+ self.train_fraction = train_fraction
+ self.mapping = list(essentials["characters"])
+ self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)}
+ self.data_train = None
+ self.data_val = None
+ self.data_test = None
+ self.transform = transforms.Compose([transforms.ToTensor()])
+ self.dims = (1, *essentials["input_shape"])
+ self.output_dims = (1,)
+
+ def prepare_data(self) -> None:
+ if not PROCESSED_DATA_FILENAME.exists():
+ _download_and_process_emnist()
+
+ def setup(self, stage: Optional[str] = None) -> None:
+ if stage == "fit" or stage is None:
+ with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
+ self.x_train = f["x_train"][:]
+ self.y_train = f["y_train"][:].squeeze().astype(int)
+
+ dataset_train = BaseDataset(
+ self.x_train, self.y_train, transform=self.transform
+ )
+ train_size = int(self.train_fraction * len(dataset_train))
+ val_size = len(dataset_train) - train_size
+ self.data_train, self.data_val = random_split(
+ dataset_train, [train_size, val_size], generator=torch.Generator()
+ )
+
+ if stage == "test" or stage is None:
+ with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
+ self.x_test = f["x_test"][:]
+ self.y_test = f["y_test"][:].squeeze().astype(int)
+ self.data_test = BaseDataset(
+ self.x_test, self.y_test, transform=self.transform
+ )
+
+ def __repr__(self) -> str:
+ basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.dims}\n"
+ if not any([self.data_train, self.data_val, self.data_test]):
+ return basic
+
+ datum, target = next(iter(self.train_dataloader()))
+ data = (
+ f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n"
+ f"Batch x stats: {(datum.shape, datum.dtype, datum.min(), datum.mean(), datum.std(), datum.max())}\n"
+ f"Batch y stats: {(target.shape, target.dtype, target.min(), target.max())}\n"
+ )
+
+ return basic + data
+
+
+def _download_and_process_emnist() -> None:
+ metadata = toml.load(METADATA_FILENAME)
+ download_dataset(metadata, DL_DATA_DIRNAME)
+ _process_raw_dataset(metadata["filename"], DL_DATA_DIRNAME)
+
+
+def _process_raw_dataset(filename: str, dirname: Path) -> None:
+ logger.info("Unzipping EMNIST...")
+ curdir = os.getcwd()
+ os.chdir(dirname)
+ content = zipfile.ZipFile(filename, "r")
+ content.extract("matlab/emnist-byclass.mat")
+
+ from scipy.io import loadmat
+
+ logger.info("Loading training data from .mat file")
+ data = loadmat("matlab/emnist-byclass.mat")
+ x_train = (
+ data["dataset"]["train"][0, 0]["images"][0, 0]
+ .reshape(-1, 28, 28)
+ .swapaxes(1, 2)
+ )
+ y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
+ x_test = (
+ data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2)
+ )
+ y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
+
+ if SAMPLE_TO_BALANCE:
+ logger.info("Balancing classes to reduce amount of data")
+ x_train, y_train = _sample_to_balance(x_train, y_train)
+ x_test, y_test = _sample_to_balance(x_test, y_test)
+
+ logger.info("Saving to HDF5 in a compressed format...")
+ PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+ with h5py.File(PROCESSED_DATA_FILENAME, "w") as f:
+ f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf")
+ f.create_dataset("y_train", data=y_train, dtype="u1", compression="lzf")
+ f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf")
+ f.create_dataset("y_test", data=y_test, dtype="u1", compression="lzf")
+
+ logger.info("Saving essential dataset parameters to text_recognizer/datasets...")
+ mapping = {int(k): chr(v) for k, v in data["dataset"]["mapping"][0, 0]}
+ characters = _augment_emnist_characters(mapping.values())
+ essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])}
+
+ with ESSENTIALS_FILENAME.open(mode="w") as f:
+ json.dump(essentials, f)
+
+ logger.info("Cleaning up...")
+ shutil.rmtree("matlab")
+ os.chdir(curdir)
+
+
+def _sample_to_balance(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ """Balances the dataset by taking the mean number of instances per class."""
+ np.random.seed(SEED)
+ num_to_sample = int(np.bincount(y.flatten()).mean())
+ all_sampled_indices = []
+ for label in np.unique(y.flatten()):
+ indices = np.where(y == label)[0]
+ sampled_indices = np.unique(np.random.choice(indices, num_to_sample))
+ all_sampled_indices.append(sampled_indices)
+ indices = np.concatenate(all_sampled_indices)
+ x_sampled = x[indices]
+ y_sampled = y[indices]
+ return x_sampled, y_sampled
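+
+# Note on the sampling above: np.random.choice draws with replacement and
+# np.unique then removes duplicate indices, so a class with fewer than
+# num_to_sample examples keeps at most its own unique draws, while larger
+# classes contribute at most num_to_sample (typically somewhat fewer, since
+# duplicate draws are discarded). E.g. with class counts [100, 300, 500] the
+# mean is 300, and the 100-example class is never oversampled.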
+
+
+def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]:
+ """Augment the mapping with extra symbols."""
+ # Extra characters from the IAM dataset.
+ iam_characters = [
+ " ",
+ "!",
+ '"',
+ "#",
+ "&",
+ "'",
+ "(",
+ ")",
+ "*",
+ "+",
+ ",",
+ "-",
+ ".",
+ "/",
+ ":",
+ ";",
+ "?",
+ ]
+
+ # Also add special tokens for:
+ # - CTC blank token at index 0
+ # - Start token at index 1
+ # - End token at index 2
+ # - Padding token at index 3
+ # Note: Do not forget to update NUM_SPECIAL_TOKENS if changing this!
+ return ["<b>", "<s>", "</s>", "<p>", *characters, *iam_characters]
+
+
+def download_emnist() -> None:
+ """Download dataset from internet, if it does not exists, and displays info."""
+ load_and_print_info(EMNIST)
diff --git a/text_recognizer/data/emnist_essentials.json b/text_recognizer/data/emnist_essentials.json
new file mode 100644
index 0000000..3f46a73
--- /dev/null
+++ b/text_recognizer/data/emnist_essentials.json
@@ -0,0 +1 @@
+{"characters": ["<b>", "<s>", "</s>", "<p>", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]}
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
new file mode 100644
index 0000000..6c14add
--- /dev/null
+++ b/text_recognizer/data/emnist_lines.py
@@ -0,0 +1,280 @@
+"""Dataset of generated text from EMNIST characters."""
+from collections import defaultdict
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Sequence, Tuple
+
+import h5py
+from loguru import logger
+import numpy as np
+from PIL import Image
+import torch
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+
+from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels
+from text_recognizer.data.base_data_module import (
+ BaseDataModule,
+ load_and_print_info,
+)
+from text_recognizer.data.emnist import EMNIST
+from text_recognizer.data.sentence_generator import SentenceGenerator
+
+
+DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist_lines"
+ESSENTIALS_FILENAME = (
+ Path(__file__).parents[0].resolve() / "emnist_lines_essentials.json"
+)
+
+SEED = 4711
+IMAGE_HEIGHT = 56
+IMAGE_WIDTH = 1024
+IMAGE_X_PADDING = 28
+MAX_OUTPUT_LENGTH = 89 # Same as IAMLines
+
+
+class EMNISTLines(BaseDataModule):
+ """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST,"""
+
+ def __init__(
+ self,
+ augment: bool = True,
+ batch_size: int = 128,
+ num_workers: int = 0,
+ max_length: int = 32,
+ min_overlap: float = 0.0,
+ max_overlap: float = 0.33,
+ num_train: int = 10_000,
+ num_val: int = 2_000,
+ num_test: int = 2_000,
+ ) -> None:
+ super().__init__(batch_size, num_workers)
+
+ self.augment = augment
+ self.max_length = max_length
+ self.min_overlap = min_overlap
+ self.max_overlap = max_overlap
+ self.num_train = num_train
+ self.num_val = num_val
+ self.num_test = num_test
+
+ self.emnist = EMNIST()
+ self.mapping = self.emnist.mapping
+ max_width = (
+ int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap))
+ + IMAGE_X_PADDING
+ )
+
+ if max_width >= IMAGE_WIDTH:
+ raise ValueError(
+ f"max_width {max_width} greater than IMAGE_WIDTH {IMAGE_WIDTH}"
+ )
+
+ self.dims = (self.emnist.dims[0], IMAGE_HEIGHT, IMAGE_WIDTH)
+
+ if self.max_length >= MAX_OUTPUT_LENGTH:
+ raise ValueError("max_length greater than MAX_OUTPUT_LENGTH")
+
+ self.output_dims = (MAX_OUTPUT_LENGTH, 1)
+ self.data_train = None
+ self.data_val = None
+ self.data_test = None
+
+ @property
+ def data_filename(self) -> Path:
+ """Return name of dataset."""
+ return (
+ DATA_DIRNAME / (f"ml_{self.max_length}_"
+ f"o{self.min_overlap:f}_{self.max_overlap:f}_"
+ f"ntr{self.num_train}_"
+ f"ntv{self.num_val}_"
+ f"nte{self.num_test}.h5")
+ )
+
+ def prepare_data(self) -> None:
+ if self.data_filename.exists():
+ return
+ np.random.seed(SEED)
+ self._generate_data("train")
+ self._generate_data("val")
+ self._generate_data("test")
+
+ def setup(self, stage: Optional[str] = None) -> None:
+ logger.info("EMNISTLinesDataset loading data from HDF5...")
+ if stage == "fit" or stage is None:
+ with h5py.File(self.data_filename, "r") as f:
+ x_train = f["x_train"][:]
+ y_train = torch.LongTensor(f["y_train"][:])
+ x_val = f["x_val"][:]
+ y_val = torch.LongTensor(f["y_val"][:])
+
+ self.data_train = BaseDataset(
+ x_train, y_train, transform=_get_transform(augment=self.augment)
+ )
+ self.data_val = BaseDataset(
+ x_val, y_val, transform=_get_transform(augment=self.augment)
+ )
+
+ if stage == "test" or stage is None:
+ with h5py.File(self.data_filename, "r") as f:
+ x_test = f["x_test"][:]
+ y_test = torch.LongTensor(f["y_test"][:])
+
+ self.data_test = BaseDataset(
+ x_test, y_test, transform=_get_transform(augment=False)
+ )
+
+ def __repr__(self) -> str:
+ """Return str about dataset."""
+ basic = (
+ "EMNISTLines2 Dataset\n" # pylint: disable=no-member
+ f"Min overlap: {self.min_overlap}\n"
+ f"Max overlap: {self.max_overlap}\n"
+ f"Num classes: {len(self.mapping)}\n"
+ f"Dims: {self.dims}\n"
+ f"Output dims: {self.output_dims}\n"
+ )
+
+ if not any([self.data_train, self.data_val, self.data_test]):
+ return basic
+
+ x, y = next(iter(self.train_dataloader()))
+ data = (
+ f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n"
+ f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
+ f"Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n"
+ )
+ return basic + data
+
+ def _generate_data(self, split: str) -> None:
+ logger.info(f"EMNISTLines generating data for {split}...")
+ sentence_generator = SentenceGenerator(
+ self.max_length - 2
+ ) # Subtract 2 because of the start and end tokens.
+
+ emnist = self.emnist
+ emnist.prepare_data()
+ emnist.setup()
+
+ if split == "train":
+ samples_by_char = _get_samples_by_char(
+ emnist.x_train, emnist.y_train, emnist.mapping
+ )
+ num = self.num_train
+ elif split == "val":
+ samples_by_char = _get_samples_by_char(
+ emnist.x_train, emnist.y_train, emnist.mapping
+ )
+ num = self.num_val
+ else:
+ samples_by_char = _get_samples_by_char(
+ emnist.x_test, emnist.y_test, emnist.mapping
+ )
+ num = self.num_test
+
+ DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+ with h5py.File(self.data_filename, "a") as f:
+ x, y = _create_dataset_of_images(
+ num,
+ samples_by_char,
+ sentence_generator,
+ self.min_overlap,
+ self.max_overlap,
+ self.dims,
+ )
+ y = convert_strings_to_labels(
+ y, emnist.inverse_mapping, length=MAX_OUTPUT_LENGTH
+ )
+ f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf")
+ f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf")
+
+
+def _get_samples_by_char(
+ samples: np.ndarray, labels: np.ndarray, mapping: Dict
+) -> defaultdict:
+ samples_by_char = defaultdict(list)
+ for sample, label in zip(samples, labels):
+ samples_by_char[mapping[label]].append(sample)
+ return samples_by_char
+
+
+def _select_letter_samples_for_string(string: str, samples_by_char: defaultdict) -> List:
+ null_image = torch.zeros((28, 28), dtype=torch.uint8)
+ sample_image_by_char = {}
+ for char in string:
+ if char in sample_image_by_char:
+ continue
+ samples = samples_by_char[char]
+ sample = samples[np.random.choice(len(samples))] if samples else null_image
+ sample_image_by_char[char] = sample.reshape(28, 28)
+ return [sample_image_by_char[char] for char in string]
+
+
+def _construct_image_from_string(
+ string: str,
+ samples_by_char: defaultdict,
+ min_overlap: float,
+ max_overlap: float,
+ width: int,
+) -> torch.Tensor:
+ overlap = np.random.uniform(min_overlap, max_overlap)
+ sampled_images = _select_letter_samples_for_string(string, samples_by_char)
+ N = len(sampled_images)
+ H, W = sampled_images[0].shape
+ next_overlap_width = W - int(overlap * W)
+ concatenated_image = torch.zeros((H, width), dtype=torch.uint8)
+ x = IMAGE_X_PADDING
+ for image in sampled_images:
+ concatenated_image[:, x : (x + W)] += image
+ x += next_overlap_width
+ return torch.minimum(torch.Tensor([255]), concatenated_image)
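+
+# Overlap arithmetic, for illustration: with 28 px wide glyphs (W = 28) and a
+# sampled overlap of 0.25, next_overlap_width = 28 - int(0.25 * 28) = 21, so
+# each character starts 21 px after the previous one, overlapping it by 7 px.
+# Overlapping pixels are summed and clamped to 255 via torch.minimum above.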
+
+
+def _create_dataset_of_images(
+ num_samples: int,
+ samples_by_char: defaultdict,
+ sentence_generator: SentenceGenerator,
+ min_overlap: float,
+ max_overlap: float,
+ dims: Tuple,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ images = torch.zeros((num_samples, IMAGE_HEIGHT, dims[2]))
+ labels = []
+ for n in range(num_samples):
+ label = sentence_generator.generate()
+ crop = _construct_image_from_string(
+ label, samples_by_char, min_overlap, max_overlap, dims[-1]
+ )
+ height = crop.shape[0]
+ y = (IMAGE_HEIGHT - height) // 2
+ images[n, y : (y + height), :] = crop
+ labels.append(label)
+ return images, labels
+
+
+def _get_transform(augment: bool = False) -> Callable:
+ if not augment:
+ return transforms.Compose([transforms.ToTensor()])
+ return transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.ColorJitter(brightness=(0.5, 1.0)),
+ transforms.RandomAffine(
+ degrees=3,
+ translate=(0.0, 0.05),
+ scale=(0.4, 1.1),
+ shear=(-40, 50),
+ interpolation=InterpolationMode.BILINEAR,
+ fill=0,
+ ),
+ ]
+ )
+
+
+def generate_emnist_lines() -> None:
+ """Generates a synthetic handwritten dataset and displays info,"""
+ load_and_print_info(EMNISTLines)
diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py
new file mode 100644
index 0000000..fcfe9a7
--- /dev/null
+++ b/text_recognizer/data/iam.py
@@ -0,0 +1,120 @@
+"""Class for loading the IAM dataset, which encompasses both paragraphs and lines, with associated utilities."""
+import os
+from pathlib import Path
+from typing import Any, Dict, List
+import xml.etree.ElementTree as ElementTree
+import zipfile
+
+from boltons.cacheutils import cachedproperty
+from loguru import logger
+from PIL import Image
+import toml
+
+from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
+from text_recognizer.data.download_utils import download_dataset
+
+
+RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "iam"
+METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
+DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "iam"
+EXTRACTED_DATASET_DIRNAME = DL_DATA_DIRNAME / "iamdb"
+
+DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be.
+LINE_REGION_PADDING = 16 # Add this many pixels around the exact coordinates.
+
+
+class IAM(BaseDataModule):
+ """
+ "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text,
+ which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels."
+ From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database
+ The data split we will use is
+ IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text lines.
+ The validation set has been merged into the train set.
+ The train set has 7,101 lines from 326 writers.
+ The test set has 1,861 lines from 128 writers.
+ The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only.
+ """
+
+ def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None:
+ super().__init__(batch_size, num_workers)
+ self.metadata = toml.load(METADATA_FILENAME)
+
+ def prepare_data(self) -> None:
+ if self.xml_filenames:
+ return
+ filename = download_dataset(self.metadata, DL_DATA_DIRNAME)
+ _extract_raw_dataset(filename, DL_DATA_DIRNAME)
+
+ @property
+ def xml_filenames(self) -> List[Path]:
+ return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml"))
+
+ @property
+ def form_filenames(self) -> List[Path]:
+ return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg"))
+
+ @property
+ def form_filenames_by_id(self) -> Dict[str, Path]:
+ return {filename.stem: filename for filename in self.form_filenames}
+
+ @property
+ def split_by_id(self) -> Dict[str, str]:
+ return {filename.stem: "test" if filename.stem in self.metadata["test_ids"] else "trainval" for filename in self.form_filenames}
+
+ @cachedproperty
+ def line_strings_by_id(self) -> Dict[str, List[str]]:
+ """Return a dict from name of IAM form to list of line texts in it."""
+ return {filename.stem: _get_line_strings_from_xml_file(filename) for filename in self.xml_filenames}
+
+ @cachedproperty
+ def line_regions_by_id(self) -> Dict[str, List[Dict[str, int]]]:
+ """Return a dict from name IAM form to list of (x1, x2, y1, y2) coordinates of all lines in it."""
+ return {filename.stem: _get_line_regions_from_xml_file(filename) for filename in self.xml_filenames}
+
+ def __repr__(self) -> str:
+ """Return info about the dataset."""
+ return ("IAM Dataset\n"
+ f"Num forms total: {len(self.xml_filenames)}\n"
+ f"Num in test set: {len(self.metadata['test_ids'])}\n")
+
+
+def _extract_raw_dataset(filename: Path, dirname: Path) -> None:
+ logger.info("Extracting IAM data...")
+ curdir = os.getcwd()
+ os.chdir(dirname)
+ with zipfile.ZipFile(filename, "r") as f:
+ f.extractall()
+ os.chdir(curdir)
+
+
+def _get_line_strings_from_xml_file(filename: str) -> List[str]:
+ """Get the text content of each line. Note that we replace &quot: with "."""
+ xml_root_element = ElementTree.parse(filename).getroot() # nosec
+ xml_line_elements = xml_root_element.findall("handwritten-part/line")
+ return [el.attrib["text"].replace("&quot;", '"') for el in xml_line_elements]
+
+
+def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]:
+ """Get line region dict for each line."""
+ xml_root_element = ElementTree.parse(filename).getroot() # nosec
+ xml_line_elements = xml_root_element.findall("handwritten-part/line")
+ return [_get_line_region_from_xml_element(el) for el in xml_line_elements]
+
+
+def _get_line_region_from_xml_element(xml_line: Any) -> Dict[str, int]:
+ word_elements = xml_line.findall("word/cmp")
+ x1s = [int(el.attrib["x"]) for el in word_elements]
+ y1s = [int(el.attrib["y"]) for el in word_elements]
+ x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements]
+ y2s = [int(el.attrib["x"]) + int(el.attrib["height"]) for el in word_elements]
+ return {
+ "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
+ "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
+ "x2": min(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
+ "y2": min(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
+ }
+
+
+def download_iam() -> None:
+ """Download the IAM dataset, if it does not exist, and display info."""
+ load_and_print_info(IAM)
diff --git a/text_recognizer/data/iam_dataset.py b/text_recognizer/data/iam_dataset.py
new file mode 100644
index 0000000..a8998b9
--- /dev/null
+++ b/text_recognizer/data/iam_dataset.py
@@ -0,0 +1,133 @@
+"""Class for loading the IAM dataset, which encompasses both paragraphs and lines, with associated utilities."""
+import os
+from typing import Any, Dict, List
+import zipfile
+
+from boltons.cacheutils import cachedproperty
+import defusedxml.ElementTree as ET
+from loguru import logger
+import toml
+
+from text_recognizer.datasets.util import _download_raw_dataset, DATA_DIRNAME
+
+RAW_DATA_DIRNAME = DATA_DIRNAME / "raw" / "iam"
+METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
+EXTRACTED_DATASET_DIRNAME = RAW_DATA_DIRNAME / "iamdb"
+RAW_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+
+DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be.
+LINE_REGION_PADDING = 0 # Add this many pixels around the exact coordinates.
+
+
+class IamDataset:
+ """IAM dataset.
+
+ "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text,
+ which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels."
+ From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database
+
+ The data split we will use is
+ IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text lines.
+ The validation set has been merged into the train set.
+ The train set has 7,101 lines from 326 writers.
+ The test set has 1,861 lines from 128 writers.
+ The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only.
+
+ """
+
+ def __init__(self) -> None:
+ self.metadata = toml.load(METADATA_FILENAME)
+
+ def load_or_generate_data(self) -> None:
+ """Downloads IAM dataset if xml files does not exist."""
+ if not self.xml_filenames:
+ self._download_iam()
+
+ @property
+ def xml_filenames(self) -> List:
+ """List of xml filenames."""
+ return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml"))
+
+ @property
+ def form_filenames(self) -> List:
+ """List of forms filenames."""
+ return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg"))
+
+ def _download_iam(self) -> None:
+ curdir = os.getcwd()
+ os.chdir(RAW_DATA_DIRNAME)
+ _download_raw_dataset(self.metadata)
+ _extract_raw_dataset(self.metadata)
+ os.chdir(curdir)
+
+ @property
+ def form_filenames_by_id(self) -> Dict:
+ """Creates a dictionary with filenames as keys and forms as values."""
+ return {filename.stem: filename for filename in self.form_filenames}
+
+ @cachedproperty
+ def line_strings_by_id(self) -> Dict:
+ """Return a dict from name of IAM form to a list of line texts in it."""
+ return {
+ filename.stem: _get_line_strings_from_xml_file(filename)
+ for filename in self.xml_filenames
+ }
+
+ @cachedproperty
+ def line_regions_by_id(self) -> Dict:
+ """Return a dict from name of IAM form to a list of (x1, x2, y1, y2) coordinates of all lines in it."""
+ return {
+ filename.stem: _get_line_regions_from_xml_file(filename)
+ for filename in self.xml_filenames
+ }
+
+ def __repr__(self) -> str:
+ """Print info about dataset."""
+ return "IAM Dataset\n" f"Number of forms: {len(self.xml_filenames)}\n"
+
+
+def _extract_raw_dataset(metadata: Dict) -> None:
+ logger.info("Extracting IAM data.")
+ with zipfile.ZipFile(metadata["filename"], "r") as zip_file:
+ zip_file.extractall()
+
+
+def _get_line_strings_from_xml_file(filename: str) -> List[str]:
+ """Get the text content of each line. Note that we replace &quot; with "."""
+ xml_root_element = ET.parse(filename).getroot() # nosec
+ xml_line_elements = xml_root_element.findall("handwritten-part/line")
+ return [el.attrib["text"].replace("&quot;", '"') for el in xml_line_elements]
+
+
+def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]:
+ """Get the line region dict for each line."""
+ xml_root_element = ET.parse(filename).getroot() # nosec
+ xml_line_elements = xml_root_element.findall("handwritten-part/line")
+ return [_get_line_region_from_xml_element(el) for el in xml_line_elements]
+
+
+def _get_line_region_from_xml_element(xml_line: Any) -> Dict[str, int]:
+ """Extracts coordinates for each line of text."""
+ # TODO: fix input!
+ word_elements = xml_line.findall("word/cmp")
+ x1s = [int(el.attrib["x"]) for el in word_elements]
+ y1s = [int(el.attrib["y"]) for el in word_elements]
+ x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements]
+ y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements]
+ return {
+ "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
+ "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
+ "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
+ "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
+ }
+
+
+def main() -> None:
+ """Initializes the dataset and print info about the dataset."""
+ dataset = IamDataset()
+ dataset.load_or_generate_data()
+ print(dataset)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/text_recognizer/data/iam_lines_dataset.py b/text_recognizer/data/iam_lines_dataset.py
new file mode 100644
index 0000000..1cb84bd
--- /dev/null
+++ b/text_recognizer/data/iam_lines_dataset.py
@@ -0,0 +1,110 @@
+"""IamLinesDataset class."""
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import h5py
+from loguru import logger
+import torch
+from torch import Tensor
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets.dataset import Dataset
+from text_recognizer.datasets.util import (
+ compute_sha256,
+ DATA_DIRNAME,
+ download_url,
+ EmnistMapper,
+)
+
+
+PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines"
+PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "iam_lines.h5"
+PROCESSED_DATA_URL = (
+ "https://s3-us-west-2.amazonaws.com/fsdl-public-assets/iam_lines.h5"
+)
+
+
+class IamLinesDataset(Dataset):
+ """IAM lines datasets for handwritten text lines."""
+
+ def __init__(
+ self,
+ train: bool = False,
+ subsample_fraction: Optional[float] = None,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ init_token: Optional[str] = None,
+ pad_token: Optional[str] = None,
+ eos_token: Optional[str] = None,
+ lower: bool = False,
+ ) -> None:
+ self.pad_token = "_" if pad_token is None else pad_token
+
+ super().__init__(
+ train=train,
+ subsample_fraction=subsample_fraction,
+ transform=transform,
+ target_transform=target_transform,
+ init_token=init_token,
+ pad_token=self.pad_token,
+ eos_token=eos_token,
+ lower=lower,
+ )
+
+ @property
+ def input_shape(self) -> Tuple:
+ """Input shape of the data."""
+ return self.data.shape[1:] if self.data is not None else None
+
+ @property
+ def output_shape(self) -> Tuple:
+ """Output shape of the data."""
+ return (
+ self.targets.shape[1:] + (self.num_classes,)
+ if self.targets is not None
+ else None
+ )
+
+ def load_or_generate_data(self) -> None:
+ """Load or generate dataset data."""
+ if not PROCESSED_DATA_FILENAME.exists():
+ PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+ logger.info("Downloading IAM lines...")
+ download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME)
+ with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
+ self._data = f[f"x_{self.split}"][:]
+ self._targets = f[f"y_{self.split}"][:]
+ self._subsample()
+
+ def __repr__(self) -> str:
+ """Print info about the dataset."""
+ return (
+ "IAM Lines Dataset\n" # pylint: disable=no-member
+ f"Number classes: {self.num_classes}\n"
+ f"Mapping: {self.mapper.mapping}\n"
+ f"Data: {self.data.shape}\n"
+ f"Targets: {self.targets.shape}\n"
+ )
+
+ def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]:
+ """Fetches data, target pair of the dataset for a given and index or indices.
+
+ Args:
+ index (Union[int, Tensor]): Either a list or int of indices/index.
+
+ Returns:
+ Tuple[Tensor, Tensor]: Data target pair.
+
+ """
+ if torch.is_tensor(index):
+ index = index.tolist()
+
+ data = self.data[index]
+ targets = self.targets[index]
+
+ if self.transform:
+ data = self.transform(data)
+
+ if self.target_transform:
+ targets = self.target_transform(targets)
+
+ return data, targets
diff --git a/text_recognizer/data/iam_paragraphs_dataset.py b/text_recognizer/data/iam_paragraphs_dataset.py
new file mode 100644
index 0000000..8ba5142
--- /dev/null
+++ b/text_recognizer/data/iam_paragraphs_dataset.py
@@ -0,0 +1,291 @@
+"""IamParagraphsDataset class and functions for data processing."""
+import random
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import click
+import cv2
+import h5py
+from loguru import logger
+import numpy as np
+import torch
+from torch import Tensor
+from torchvision.transforms import ToTensor
+
+from text_recognizer import util
+from text_recognizer.datasets.dataset import Dataset
+from text_recognizer.datasets.iam_dataset import IamDataset
+from text_recognizer.datasets.util import (
+ compute_sha256,
+ DATA_DIRNAME,
+ download_url,
+ EmnistMapper,
+)
+
+INTERIM_DATA_DIRNAME = DATA_DIRNAME / "interim" / "iam_paragraphs"
+DEBUG_CROPS_DIRNAME = INTERIM_DATA_DIRNAME / "debug_crops"
+PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_paragraphs"
+CROPS_DIRNAME = PROCESSED_DATA_DIRNAME / "crops"
+GT_DIRNAME = PROCESSED_DATA_DIRNAME / "gt"
+
+PARAGRAPH_BUFFER = 50 # Pixels in the IAM form images to leave around the lines.
+TEST_FRACTION = 0.2
+SEED = 4711
+
+
+class IamParagraphsDataset(Dataset):
+ """IAM Paragraphs dataset for paragraphs of handwritten text."""
+
+ def __init__(
+ self,
+ train: bool = False,
+ subsample_fraction: Optional[float] = None,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ ) -> None:
+ super().__init__(
+ train=train,
+ subsample_fraction=subsample_fraction,
+ transform=transform,
+ target_transform=target_transform,
+ )
+ # Load Iam dataset.
+ self.iam_dataset = IamDataset()
+
+ self.num_classes = 3
+ self._input_shape = (256, 256)
+ self._output_shape = self._input_shape + (self.num_classes,)
+ self._ids = None
+
+ def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]:
+ """Fetches data, target pair of the dataset for a given and index or indices.
+
+ Args:
+ index (Union[int, Tensor]): Either a list or int of indices/index.
+
+ Returns:
+ Tuple[Tensor, Tensor]: Data target pair.
+
+ """
+ if torch.is_tensor(index):
+ index = index.tolist()
+
+ data = self.data[index]
+ targets = self.targets[index]
+
+ seed = np.random.randint(SEED)
+ random.seed(seed) # apply this seed to the image transforms
+ torch.manual_seed(seed) # needed for torchvision 0.7
+ if self.transform:
+ data = self.transform(data)
+
+ random.seed(seed) # apply this seed to the target transforms
+ torch.manual_seed(seed) # needed for torchvision 0.7
+ if self.target_transform:
+ targets = self.target_transform(targets)
+
+ return data, targets.long()
+
+ @property
+ def ids(self) -> Tensor:
+ """Ids of the dataset."""
+ return self._ids
+
+ def get_data_and_target_from_id(self, id_: str) -> Tuple[Tensor, Tensor]:
+ """Get data target pair from id."""
+ ind = self.ids.index(id_)
+ return self.data[ind], self.targets[ind]
+
+ def load_or_generate_data(self) -> None:
+ """Load or generate dataset data."""
+ num_actual = len(list(CROPS_DIRNAME.glob("*.jpg")))
+ num_targets = len(self.iam_dataset.line_regions_by_id)
+
+ if num_actual < num_targets - 2:
+ self._process_iam_paragraphs()
+
+ self._data, self._targets, self._ids = _load_iam_paragraphs()
+ self._get_random_split()
+ self._subsample()
+
+ def _get_random_split(self) -> None:
+ np.random.seed(SEED)
+ num_train = int((1 - TEST_FRACTION) * self.data.shape[0])
+ indices = np.random.permutation(self.data.shape[0])
+ train_indices, test_indices = indices[:num_train], indices[num_train:]
+ if self.train:
+ self._data = self.data[train_indices]
+ self._targets = self.targets[train_indices]
+ else:
+ self._data = self.data[test_indices]
+ self._targets = self.targets[test_indices]
+
+ def _process_iam_paragraphs(self) -> None:
+ """Crop the part with the text.
+
+ For each page, crop out the part of it that corresponds to the paragraph of text, and make sure all crops are
+ self.input_shape. The ground truth data is the same size, with a one-hot vector at each pixel
+ corresponding to labels 0=background, 1=odd-numbered line, 2=even-numbered line
+ """
+ crop_dims = self._decide_on_crop_dims()
+ CROPS_DIRNAME.mkdir(parents=True, exist_ok=True)
+ DEBUG_CROPS_DIRNAME.mkdir(parents=True, exist_ok=True)
+ GT_DIRNAME.mkdir(parents=True, exist_ok=True)
+ logger.info(
+ f"Cropping paragraphs, generating ground truth, and saving debugging images to {DEBUG_CROPS_DIRNAME}"
+ )
+ for filename in self.iam_dataset.form_filenames:
+ id_ = filename.stem
+ line_region = self.iam_dataset.line_regions_by_id[id_]
+ _crop_paragraph_image(filename, line_region, crop_dims, self.input_shape)
+
+ def _decide_on_crop_dims(self) -> Tuple[int, int]:
+ """Decide on the dimensions to crop out of the form image.
+
+ Since image width is larger than a comfortable crop around the longest paragraph,
+ we will make the crop a square form factor.
+ And since the found dimensions 610x610 are pretty close to 512x512,
+ we might as well resize crops and make it exactly that, which lets us
+ do all kinds of power-of-2 pooling and upsampling should we choose to.
+
+ Returns:
+ Tuple[int, int]: A tuple of crop dimensions.
+
+ Raises:
+ RuntimeError: When max crop height is larger than max crop width.
+
+ """
+
+ sample_form_filename = self.iam_dataset.form_filenames[0]
+ sample_image = util.read_image(sample_form_filename, grayscale=True)
+ max_crop_width = sample_image.shape[1]
+ max_crop_height = _get_max_paragraph_crop_height(
+ self.iam_dataset.line_regions_by_id
+ )
+ if not max_crop_height <= max_crop_width:
+ raise RuntimeError(
+ f"Max crop height is larger then max crop width: {max_crop_height} >= {max_crop_width}"
+ )
+
+ crop_dims = (max_crop_width, max_crop_width)
+ logger.info(
+ f"Max crop width and height were found to be {max_crop_width}x{max_crop_height}."
+ )
+ logger.info(f"Setting them to {max_crop_width}x{max_crop_width}")
+ return crop_dims
+
+ def __repr__(self) -> str:
+ """Return info about the dataset."""
+ return (
+ "IAM Paragraph Dataset\n" # pylint: disable=no-member
+ f"Num classes: {self.num_classes}\n"
+ f"Data: {self.data.shape}\n"
+ f"Targets: {self.targets.shape}\n"
+ )
+
+
+def _get_max_paragraph_crop_height(line_regions_by_id: Dict) -> int:
+ heights = []
+ for regions in line_regions_by_id.values():
+ min_y1 = min(region["y1"] for region in regions) - PARAGRAPH_BUFFER
+ max_y2 = max(region["y2"] for region in regions) + PARAGRAPH_BUFFER
+ height = max_y2 - min_y1
+ heights.append(height)
+ return max(heights)
+
+
+def _crop_paragraph_image(
+ filename: str, line_regions: Dict, crop_dims: Tuple[int, int], final_dims: Tuple
+) -> None:
+ image = util.read_image(filename, grayscale=True)
+
+ min_y1 = min(region["y1"] for region in line_regions) - PARAGRAPH_BUFFER
+ max_y2 = max(region["y2"] for region in line_regions) + PARAGRAPH_BUFFER
+ height = max_y2 - min_y1
+ crop_height = crop_dims[0]
+ buffer = (crop_height - height) // 2
+
+ # Generate image crop.
+ image_crop = 255 * np.ones(crop_dims, dtype=np.uint8)
+ try:
+ image_crop[buffer : buffer + height] = image[min_y1:max_y2]
+ except Exception as e: # pylint: disable=broad-except
+ logger.error(f"Rescued {filename}: {e}")
+ return
+
+ # Generate ground truth.
+ gt_image = np.zeros_like(image_crop, dtype=np.uint8)
+ for index, region in enumerate(line_regions):
+ gt_image[
+ (region["y1"] - min_y1 + buffer) : (region["y2"] - min_y1 + buffer),
+ region["x1"] : region["x2"],
+ ] = (index % 2 + 1)
+
+ # Generate image for debugging.
+ import matplotlib.pyplot as plt
+
+ cmap = plt.get_cmap("Set1")
+ image_crop_for_debug = np.dstack([image_crop, image_crop, image_crop])
+ for index, region in enumerate(line_regions):
+ color = [255 * _ for _ in cmap(index)[:-1]]
+ cv2.rectangle(
+ image_crop_for_debug,
+ (region["x1"], region["y1"] - min_y1 + buffer),
+ (region["x2"], region["y2"] - min_y1 + buffer),
+ color,
+ 3,
+ )
+ image_crop_for_debug = cv2.resize(
+ image_crop_for_debug, final_dims, interpolation=cv2.INTER_AREA
+ )
+ util.write_image(image_crop_for_debug, DEBUG_CROPS_DIRNAME / f"{filename.stem}.jpg")
+
+ image_crop = cv2.resize(image_crop, final_dims, interpolation=cv2.INTER_AREA)
+ util.write_image(image_crop, CROPS_DIRNAME / f"{filename.stem}.jpg")
+
+ gt_image = cv2.resize(gt_image, final_dims, interpolation=cv2.INTER_NEAREST)
+ util.write_image(gt_image, GT_DIRNAME / f"{filename.stem}.png")
+
+
+def _load_iam_paragraphs() -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ logger.info("Loading IAM paragraph crops and ground truth from image files...")
+ images = []
+ gt_images = []
+ ids = []
+ for filename in CROPS_DIRNAME.glob("*.jpg"):
+ id_ = filename.stem
+ image = util.read_image(filename, grayscale=True)
+ image = 1.0 - image / 255
+
+ gt_filename = GT_DIRNAME / f"{id_}.png"
+ gt_image = util.read_image(gt_filename, grayscale=True)
+
+ images.append(image)
+ gt_images.append(gt_image)
+ ids.append(id_)
+ images = np.array(images).astype(np.float32)
+ gt_images = np.array(gt_images).astype(np.uint8)
+ ids = np.array(ids)
+ return images, gt_images, ids
+
+
+@click.command()
+@click.option(
+ "--subsample_fraction",
+ type=float,
+ default=None,
+ help="The subsampling factor of the dataset.",
+)
+def main(subsample_fraction: float) -> None:
+ """Load dataset and print info."""
+ logger.info("Creating train set...")
+ dataset = IamParagraphsDataset(train=True, subsample_fraction=subsample_fraction)
+ dataset.load_or_generate_data()
+ print(dataset)
+ logger.info("Creating test set...")
+ dataset = IamParagraphsDataset(subsample_fraction=subsample_fraction)
+ dataset.load_or_generate_data()
+ print(dataset)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py
new file mode 100644
index 0000000..a93eb00
--- /dev/null
+++ b/text_recognizer/data/iam_preprocessor.py
@@ -0,0 +1,196 @@
+"""Preprocessor for extracting word letters from the IAM dataset.
+
+The code is mostly stolen from:
+
+ https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py
+
+"""
+
+import collections
+import itertools
+from pathlib import Path
+import re
+from typing import Iterable, List, Optional, Union
+
+import click
+from loguru import logger
+import torch
+
+
+def load_metadata(
+ data_dir: Path, wordsep: str, use_words: bool = False
+) -> collections.defaultdict:
+ """Loads IAM metadata and returns it as a dictionary."""
+ forms = collections.defaultdict(list)
+ filename = "words.txt" if use_words else "lines.txt"
+
+ with open(data_dir / "ascii" / filename, "r") as f:
+ lines = (line.strip().split() for line in f if line[0] != "#")
+ for line in lines:
+ # Skip word segmentation errors.
+ if use_words and line[1] == "err":
+ continue
+ text = " ".join(line[8:])
+
+ # Remove garbage tokens:
+ text = text.replace("#", "")
+
+ # Swap word sep form | to wordsep
+ text = re.sub(r"\|+|\s", wordsep, text).strip(wordsep)
+ form_key = "-".join(line[0].split("-")[:2])
+ line_key = "-".join(line[0].split("-")[:3])
+ box_idx = 4 - use_words
+ box = tuple(int(val) for val in line[box_idx : box_idx + 4])
+ forms[form_key].append({"key": line_key, "box": box, "text": text})
+ return forms
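+
+# Illustrative lines.txt row (IAM ascii format, abbreviated):
+#   a01-000u-00 ok 154 19 408 746 1663 91 A|MOVE|to|stop|Mr.|Gaitskell
+# Here line[0] yields form key "a01-000u" and line key "a01-000u-00",
+# line[4:8] is the bounding box, and line[8:] is the transcription, with "|"
+# rewritten to the word separator.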
+
+
+class Preprocessor:
+ """A preprocessor for the IAM dataset."""
+
+ # TODO: add lower case only to when generating...
+
+ def __init__(
+ self,
+ data_dir: Union[str, Path],
+ num_features: int,
+ tokens_path: Optional[Union[str, Path]] = None,
+ lexicon_path: Optional[Union[str, Path]] = None,
+ use_words: bool = False,
+ prepend_wordsep: bool = False,
+ ) -> None:
+ self.wordsep = "▁"
+ self._use_word = use_words
+ self._prepend_wordsep = prepend_wordsep
+
+ self.data_dir = Path(data_dir)
+
+ self.forms = load_metadata(self.data_dir, self.wordsep, use_words=use_words)
+
+ # Load the set of graphemes:
+ graphemes = set()
+ for _, form in self.forms.items():
+ for line in form:
+ graphemes.update(line["text"].lower())
+ self.graphemes = sorted(graphemes)
+
+ # Build the token-to-index and index-to-token maps.
+ if tokens_path is not None:
+ with open(tokens_path, "r") as f:
+ self.tokens = [line.strip() for line in f]
+ else:
+ self.tokens = self.graphemes
+
+ if lexicon_path is not None:
+ with open(lexicon_path, "r") as f:
+ lexicon = (line.strip().split() for line in f)
+ lexicon = {line[0]: line[1:] for line in lexicon}
+ self.lexicon = lexicon
+ else:
+ self.lexicon = None
+
+ self.graphemes_to_index = {t: i for i, t in enumerate(self.graphemes)}
+ self.tokens_to_index = {t: i for i, t in enumerate(self.tokens)}
+ self.num_features = num_features
+ self.text = []
+
+ @property
+ def num_tokens(self) -> int:
+ """Returns the number or tokens."""
+ return len(self.tokens)
+
+ @property
+ def use_words(self) -> bool:
+ """If words are used."""
+ return self._use_word
+
+ def extract_train_text(self) -> None:
+ """Extracts training text."""
+ keys = []
+ with open(self.data_dir / "task" / "trainset.txt") as f:
+ keys.extend((line.strip() for line in f))
+
+ for _, examples in self.forms.items():
+ for example in examples:
+ if example["key"] not in keys:
+ continue
+ self.text.append(example["text"].lower())
+
+ def to_index(self, line: str) -> torch.LongTensor:
+ """Converts text to a tensor of indices."""
+ token_to_index = self.graphemes_to_index
+ if self.lexicon is not None:
+ if len(line) > 0:
+ # If the word is not found in the lexicon, fall back to letters.
+ line = [
+ t
+ for w in line.split(self.wordsep)
+ for t in self.lexicon.get(w, self.wordsep + w)
+ ]
+ token_to_index = self.tokens_to_index
+ if self._prepend_wordsep:
+ line = itertools.chain([self.wordsep], line)
+ return torch.LongTensor([token_to_index[t] for t in line])
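+
+ # Round-trip sketch (letter-level, no lexicon, assuming both characters are
+ # among the extracted graphemes):
+ #   indices = preprocessor.to_index("a▁b")
+ #   preprocessor.to_text(indices.tolist())  # -> "a▁b"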
+
+ def to_text(self, indices: List[int]) -> str:
+ """Converts indices to text."""
+ # Roughly the inverse of `to_index`
+ encoding = self.graphemes
+ if self.lexicon is not None:
+ encoding = self.tokens
+ return self._post_process(encoding[i] for i in indices)
+
+ def tokens_to_text(self, indices: List[int]) -> str:
+ """Converts tokens to text."""
+ return self._post_process(self.tokens[i] for i in indices)
+
+ def _post_process(self, tokens: Iterable[str]) -> str:
+ """Joins the tokens and strips word separators."""
+ return "".join(tokens).strip(self.wordsep)
+
+
+@click.command()
+@click.option("--data_dir", type=str, default=None, help="Path to iam dataset")
+@click.option(
+ "--use_words", is_flag=True, help="Load word segmented dataset instead of lines"
+)
+@click.option(
+ "--save_text", type=str, default=None, help="Path to save parsed train text"
+)
+@click.option("--save_tokens", type=str, default=None, help="Path to save tokens")
+def cli(
+ data_dir: Optional[str],
+ use_words: bool,
+ save_text: Optional[str],
+ save_tokens: Optional[str],
+) -> None:
+ """CLI for extracting text data from the iam dataset."""
+ if data_dir is None:
+ data_dir = (
+ Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb"
+ )
+ logger.debug(f"Using data dir: {data_dir}")
+ if not data_dir.exists():
+ raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+ else:
+ data_dir = Path(data_dir)
+
+ preprocessor = Preprocessor(data_dir, 64, use_words=use_words)
+ preprocessor.extract_train_text()
+
+ processed_dir = data_dir.parents[2] / "processed" / "iam_lines"
+ logger.debug(f"Saving processed files at: {processed_dir}")
+
+ if save_text is not None:
+ logger.info("Saving training text")
+ with open(processed_dir / save_text, "w") as f:
+ f.write("\n".join(t for t in preprocessor.text))
+
+ if save_tokens is not None:
+ logger.info("Saving tokens")
+ with open(processed_dir / save_tokens, "w") as f:
+ f.write("\n".join(preprocessor.tokens))
+
+
+if __name__ == "__main__":
+ cli()
diff --git a/text_recognizer/data/sentence_generator.py b/text_recognizer/data/sentence_generator.py
new file mode 100644
index 0000000..53b781c
--- /dev/null
+++ b/text_recognizer/data/sentence_generator.py
@@ -0,0 +1,85 @@
+"""Downloading the Brown corpus with NLTK for sentence generating."""
+
+import itertools
+import re
+import string
+from typing import Optional
+
+import nltk
+from nltk.corpus.reader.util import ConcatenatedCorpusView
+import numpy as np
+
+from text_recognizer.datasets.util import DATA_DIRNAME
+
+NLTK_DATA_DIRNAME = DATA_DIRNAME / "downloaded" / "nltk"
+
+
+class SentenceGenerator:
+ """Generates text sentences using the Brown corpus."""
+
+ def __init__(self, max_length: Optional[int] = None) -> None:
+ """Loads the corpus and sets word start indices."""
+ self.corpus = brown_corpus()
+ self.word_start_indices = [0] + [
+ _.start(0) + 1 for _ in re.finditer(" ", self.corpus)
+ ]
+ self.max_length = max_length
+
+ def generate(self, max_length: Optional[int] = None) -> str:
+ """Generates a word or sentences from the Brown corpus.
+
+ Sample a string from the Brown corpus of length at least one word and at most max_length, padding to
+ max_length with the '_' characters if sentence is shorter.
+
+ Args:
+ max_length (Optional[int]): The maximum number of characters in the sentence. Defaults to None.
+
+ Returns:
+ str: A sentence from the Brown corpus.
+
+ Raises:
+ ValueError: If max_length was not specified at initialization and not given as an argument.
+
+ """
+ if max_length is None:
+ max_length = self.max_length
+ if max_length is None:
+ raise ValueError(
+ "Must provide max_length to this method or when making this object."
+ )
+
+ for _ in range(10):
+ try:
+ index = np.random.randint(0, len(self.word_start_indices) - 1)
+ start_index = self.word_start_indices[index]
+ end_index_candidates = []
+ for index in range(index + 1, len(self.word_start_indices)):
+ if self.word_start_indices[index] - start_index > max_length:
+ break
+ end_index_candidates.append(self.word_start_indices[index])
+ end_index = np.random.choice(end_index_candidates)
+ sampled_text = self.corpus[start_index:end_index].strip()
+ return sampled_text
+ except Exception:
+ pass
+ raise RuntimeError("Was not able to generate a valid string")
+
+
+def brown_corpus() -> str:
+ """Returns a single string with the Brown corpus with all punctuations stripped."""
+ sentences = load_nltk_brown_corpus()
+ corpus = " ".join(itertools.chain.from_iterable(sentences))
+ corpus = corpus.translate({ord(c): None for c in string.punctuation})
+ corpus = re.sub(" +", " ", corpus)
+ return corpus
+
+
+def load_nltk_brown_corpus() -> ConcatenatedCorpusView:
+ """Load the Brown corpus using the NLTK library."""
+ nltk.data.path.append(NLTK_DATA_DIRNAME)
+ try:
+ nltk.corpus.brown.sents()
+ except LookupError:
+ NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+ nltk.download("brown", download_dir=NLTK_DATA_DIRNAME)
+ return nltk.corpus.brown.sents()
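+
+# Minimal usage sketch (assumes NLTK can reach the network to download the
+# Brown corpus on first use; the sampled text is illustrative):
+#
+# generator = SentenceGenerator(max_length=32)
+# line = generator.generate() # e.g. "he rose from his chair"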
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
new file mode 100644
index 0000000..b6a48f5
--- /dev/null
+++ b/text_recognizer/data/transforms.py
@@ -0,0 +1,266 @@
+"""Transforms for PyTorch datasets."""
+from abc import abstractmethod
+from pathlib import Path
+import random
+from typing import Any, Optional, Union
+
+from loguru import logger
+import numpy as np
+from PIL import Image
+import torch
+from torch import Tensor
+import torch.nn.functional as F
+from torchvision import transforms
+from torchvision.transforms import (
+ ColorJitter,
+ Compose,
+ Normalize,
+ RandomAffine,
+ RandomHorizontalFlip,
+ RandomRotation,
+ ToPILImage,
+ ToTensor,
+)
+
+from text_recognizer.data.iam_preprocessor import Preprocessor
+from text_recognizer.data.util import EmnistMapper
+
+
+class RandomResizeCrop:
+ """Image transform with random resize and crop applied.
+
+ Adapted from
+
+ https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py
+
+ """
+
+ def __init__(self, jitter: int = 10, ratio: float = 0.5) -> None:
+ self.jitter = jitter
+ self.ratio = ratio
+
+ def __call__(self, img: Image.Image) -> Image.Image:
+ """Applies random padding, cropping, and aspect-ratio jitter to an image."""
+ w, h = img.size
+
+ # pad with white:
+ img = transforms.functional.pad(img, self.jitter, fill=255)
+
+ # crop at random (x, y):
+ x = self.jitter + random.randint(-self.jitter, self.jitter)
+ y = self.jitter + random.randint(-self.jitter, self.jitter)
+
+ # randomize aspect ratio:
+ size_w = w * random.uniform(1 - self.ratio, 1 + self.ratio)
+ size = (h, int(size_w))
+ img = transforms.functional.resized_crop(img, y, x, h, w, size)
+ return img
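+
+# Usage sketch (hedged): applied to a grayscale PIL image of an IAM line;
+# the file name is a placeholder.
+#
+# transform = RandomResizeCrop(jitter=10, ratio=0.5)
+# augmented = transform(Image.open("line.png").convert("L"))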
+
+
+class Transpose:
+ """Transposes the EMNIST image to the correct orientation."""
+
+ def __call__(self, image: Image.Image) -> np.ndarray:
+ """Swaps axis."""
+ return np.array(image).swapaxes(0, 1)
+
+
+class Resize:
+ """Resizes a tensor to a specified width."""
+
+ def __init__(self, width: int = 952) -> None:
+ # The default is 952 because of the IAM dataset.
+ self.width = width
+
+ def __call__(self, image: Tensor) -> Tensor:
+ """Resize tensor in the last dimension."""
+ return F.interpolate(image, size=self.width, mode="nearest")
+
+
+class AddTokens:
+ """Adds start of sequence and end of sequence tokens to target tensor."""
+
+ def __init__(self, pad_token: str, eos_token: str, init_token: str = None) -> None:
+ self.init_token = init_token
+ self.pad_token = pad_token
+ self.eos_token = eos_token
+ if self.init_token is not None:
+ self.emnist_mapper = EmnistMapper(
+ init_token=self.init_token,
+ pad_token=self.pad_token,
+ eos_token=self.eos_token,
+ )
+ else:
+ self.emnist_mapper = EmnistMapper(
+ pad_token=self.pad_token, eos_token=self.eos_token,
+ )
+ self.pad_value = self.emnist_mapper(self.pad_token)
+ self.eos_value = self.emnist_mapper(self.eos_token)
+
+ def __call__(self, target: Tensor) -> Tensor:
+ """Adds a sos token to the begining and a eos token to the end of a target sequence."""
+ dtype, device = target.dtype, target.device
+
+ # Find the where padding starts.
+ pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item()
+
+ target[pad_index] = self.eos_value
+
+ if self.init_token is not None:
+ self.sos_value = self.emnist_mapper(self.init_token)
+ sos = torch.tensor([self.sos_value], dtype=dtype, device=device)
+ target = torch.cat([sos, target], dim=0)
+
+ return target
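+
+# Worked example (the index values are illustrative, not the real EMNIST
+# mapping): with pad=0, eos=1, sos=2, AddTokens turns
+# [5, 3, 9, 0, 0] into [2, 5, 3, 9, 1, 0]
+# (the first pad position becomes eos, and sos is prepended).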
+
+
+class ApplyContrast:
+ """Sets everything below a threshold to zero, i.e. increase contrast."""
+
+ def __init__(self, low: float = 0.0, high: float = 0.25) -> None:
+ self.low = low
+ self.high = high
+
+ def __call__(self, x: Tensor) -> Tensor:
+ """Applies a random binary threshold mask to the input tensor."""
+ # Draw a fresh threshold in [low, high) for every call.
+ mask = x > np.random.RandomState().uniform(low=self.low, high=self.high)
+ return x * mask
+
+
+class Unsqueeze:
+ """Add a dimension to the tensor."""
+
+ def __call__(self, x: Tensor) -> Tensor:
+ """Adds dim."""
+ return x.unsqueeze(0)
+
+
+class Squeeze:
+ """Removes the first dimension of a tensor."""
+
+ def __call__(self, x: Tensor) -> Tensor:
+ """Removes first dim."""
+ return x.squeeze(0)
+
+
+class ToLower:
+ """Converts target to lower case."""
+
+ def __call__(self, target: Tensor) -> Tensor:
+ """Corrects index value in target tensor."""
+ device = target.device
+ return torch.stack([x - 26 if x > 35 else x for x in target]).to(device)
+
+
+class ToCharcters:
+ """Converts integers to characters."""
+
+ def __init__(
+ self, pad_token: str, eos_token: str, init_token: str = None, lower: bool = True
+ ) -> None:
+ self.init_token = init_token
+ self.pad_token = pad_token
+ self.eos_token = eos_token
+ if self.init_token is not None:
+ self.emnist_mapper = EmnistMapper(
+ init_token=self.init_token,
+ pad_token=self.pad_token,
+ eos_token=self.eos_token,
+ lower=lower,
+ )
+ else:
+ self.emnist_mapper = EmnistMapper(
+ pad_token=self.pad_token, eos_token=self.eos_token, lower=lower
+ )
+
+ def __call__(self, y: Tensor) -> str:
+ """Converts a Tensor to a str."""
+ return (
+ "".join([self.emnist_mapper(int(i)) for i in y])
+ .strip("_")
+ .replace(" ", "▁")
+ )
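+
+# Usage sketch (hedged; the decoded characters depend on the EMNIST mapping,
+# and the token strings here are placeholders):
+#
+# decode = ToCharcters(pad_token="_", eos_token="<eos>")
+# decode(y) # e.g. "he▁rose▁from▁his▁chair<eos>"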
+
+
+class WordPieces:
+ """Abstract transform for word pieces."""
+
+ def __init__(
+ self,
+ num_features: int,
+ data_dir: Optional[Union[str, Path]] = None,
+ tokens: Optional[Union[str, Path]] = None,
+ lexicon: Optional[Union[str, Path]] = None,
+ use_words: bool = False,
+ prepend_wordsep: bool = False,
+ ) -> None:
+ if data_dir is None:
+ data_dir = (
+ Path(__file__).resolve().parents[2] / "data" / "raw" / "iam" / "iamdb"
+ )
+ else:
+ data_dir = Path(data_dir)
+ logger.debug(f"Using data dir: {data_dir}")
+ if not data_dir.exists():
+ raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+ # parents[2] is the repository root, matching BaseDataModule.data_dirname().
+ processed_path = (
+ Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines"
+ )
+ tokens_path = processed_path / tokens
+ lexicon_path = processed_path / lexicon
+
+ self.preprocessor = Preprocessor(
+ data_dir,
+ num_features,
+ tokens_path,
+ lexicon_path,
+ use_words,
+ prepend_wordsep,
+ )
+
+ @abstractmethod
+ def __call__(self, *args, **kwargs) -> Any:
+ """Transforms input."""
+ ...
+
+
+class ToWordPieces(WordPieces):
+ """Transforms str to word pieces."""
+
+ def __init__(
+ self,
+ num_features: int,
+ data_dir: Optional[Union[str, Path]] = None,
+ tokens: Optional[Union[str, Path]] = None,
+ lexicon: Optional[Union[str, Path]] = None,
+ use_words: bool = False,
+ prepend_wordsep: bool = False,
+ ) -> None:
+ super().__init__(
+ num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep
+ )
+
+ def __call__(self, line: str) -> Tensor:
+ """Transforms str to word pieces."""
+ return self.preprocessor.to_index(line)
+
+
+class ToText(WordPieces):
+ """Takes word pieces and converts them to text."""
+
+ def __init__(
+ self,
+ num_features: int,
+ data_dir: Optional[Union[str, Path]] = None,
+ tokens: Optional[Union[str, Path]] = None,
+ lexicon: Optional[Union[str, Path]] = None,
+ use_words: bool = False,
+ prepend_wordsep: bool = False,
+ ) -> None:
+ super().__init__(
+ num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep
+ )
+
+ def __call__(self, x: Tensor) -> str:
+ """Converts tensor to text."""
+ return self.preprocessor.to_text(x.tolist())
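+
+# Round-trip sketch (hedged: the token and lexicon file names are assumed to
+# be the artifacts written by the iam_preprocessor CLI, and the sample line
+# is illustrative):
+#
+# to_pieces = ToWordPieces(num_features=64, tokens="tokens.txt", lexicon="lexicon.txt")
+# to_text = ToText(num_features=64, tokens="tokens.txt", lexicon="lexicon.txt")
+# indices = to_pieces("a move to stop mr gaitskell")
+# line = to_text(indices) # should recover the input, modulo word separators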
diff --git a/text_recognizer/data/util.py b/text_recognizer/data/util.py
new file mode 100644
index 0000000..da87756
--- /dev/null
+++ b/text_recognizer/data/util.py
@@ -0,0 +1,209 @@
+"""Util functions for datasets."""
+import hashlib
+import json
+import os
+from pathlib import Path
+import string
+from typing import Dict, List, Optional, Union
+from urllib.request import urlretrieve
+
+from loguru import logger
+import numpy as np
+import torch
+from torch import Tensor
+from torchvision.datasets import EMNIST
+from tqdm import tqdm
+
+DATA_DIRNAME = Path(__file__).resolve().parents[2] / "data"
+ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json"
+
+
+def save_emnist_essentials(emnist_dataset: EMNIST) -> None:
+ """Extracts and saves the EMNIST mapping and input shape."""
+ labels = emnist_dataset.classes
+ labels.sort()
+ mapping = [(i, str(label)) for i, label in enumerate(labels)]
+ essentials = {
+ "mapping": mapping,
+ "input_shape": tuple(np.array(emnist_dataset[0][0]).shape),
+ }
+ logger.info("Saving emnist essentials...")
+ with open(ESSENTIALS_FILENAME, "w") as f:
+ json.dump(essentials, f)
+
+
+def download_emnist() -> None:
+ """Download the EMNIST dataset via the PyTorch class."""
+ logger.info(f"Data directory is: {DATA_DIRNAME}")
+ dataset = EMNIST(root=DATA_DIRNAME, split="byclass", download=True)
+ save_emnist_essentials(dataset)
+
+
+class EmnistMapper:
+ """Mapper between network output to Emnist character."""
+
+ def __init__(
+ self,
+ pad_token: str,
+ init_token: Optional[str] = None,
+ eos_token: Optional[str] = None,
+ lower: bool = False,
+ ) -> None:
+ """Loads the emnist essentials file with the mapping and input shape."""
+ self.init_token = init_token
+ self.pad_token = pad_token
+ self.eos_token = eos_token
+ self.lower = lower
+
+ self.essentials = self._load_emnist_essentials()
+ # Load dataset information.
+ self._mapping = dict(self.essentials["mapping"])
+ self._augment_emnist_mapping()
+ self._inverse_mapping = {v: k for k, v in self.mapping.items()}
+ self._num_classes = len(self.mapping)
+ self._input_shape = self.essentials["input_shape"]
+
+ def __call__(self, token: Union[str, int, np.uint8, Tensor]) -> Union[str, int]:
+ """Maps the token to emnist character or character index.
+
+ If the token is an integer (index), the method will return the Emnist character corresponding to that index.
+ If the token is a str (Emnist character), the method will return the corresponding index for that character.
+
+ Args:
+ token (Union[str, int, np.uint8, Tensor]): Either a string or index (integer).
+
+ Returns:
+ Union[str, int]: The mapping result.
+
+ Raises:
+ KeyError: If the index or string does not exist in the mapping.
+
+ """
+ if (
+ isinstance(token, (int, np.uint8)) or torch.is_tensor(token)
+ ) and int(token) in self.mapping:
+ return self.mapping[int(token)]
+ elif isinstance(token, str) and token in self._inverse_mapping:
+ return self._inverse_mapping[token]
+ else:
+ raise KeyError(f"Token {token} does not exist in the mappings.")
+
+ @property
+ def mapping(self) -> Dict:
+ """Returns the mapping between index and character."""
+ return self._mapping
+
+ @property
+ def inverse_mapping(self) -> Dict:
+ """Returns the mapping between character and index."""
+ return self._inverse_mapping
+
+ @property
+ def num_classes(self) -> int:
+ """Returns the number of classes in the dataset."""
+ return self._num_classes
+
+ @property
+ def input_shape(self) -> List[int]:
+ """Returns the input shape of the Emnist characters."""
+ return self._input_shape
+
+ def _load_emnist_essentials(self) -> Dict:
+ """Load the EMNIST mapping."""
+ with open(str(ESSENTIALS_FILENAME)) as f:
+ essentials = json.load(f)
+ return essentials
+
+ def _augment_emnist_mapping(self) -> None:
+ """Augment the mapping with extra symbols."""
+ if self.lower:
+ self._mapping = {
+ k: str(v)
+ for k, v in enumerate(list(range(10)) + list(string.ascii_lowercase))
+ }
+
+ # Extra symbols appearing in the IAM dataset.
+ extra_symbols = [
+ " ",
+ "!",
+ '"',
+ "#",
+ "&",
+ "'",
+ "(",
+ ")",
+ "*",
+ "+",
+ ",",
+ "-",
+ ".",
+ "/",
+ ":",
+ ";",
+ "?",
+ ]
+
+ # The padding symbol, which also acts as the blank symbol.
+ extra_symbols.append(self.pad_token)
+
+ if self.init_token is not None:
+ extra_symbols.append(self.init_token)
+
+ if self.eos_token is not None:
+ extra_symbols.append(self.eos_token)
+
+ max_key = max(self.mapping.keys())
+ extra_mapping = {}
+ for i, symbol in enumerate(extra_symbols):
+ extra_mapping[max_key + 1 + i] = symbol
+
+ self._mapping = {**self.mapping, **extra_mapping}
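+
+# Usage sketch (hedged; assumes emnist_essentials.json has been generated, and
+# the exact character returned depends on that mapping):
+#
+# mapper = EmnistMapper(pad_token="_")
+# char = mapper(36) # index -> character
+# index = mapper(char) # character -> index, so index == 36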
+
+
+def compute_sha256(filename: Union[Path, str]) -> str:
+ """Returns the SHA256 checksum of a file."""
+ with open(filename, "rb") as f:
+ return hashlib.sha256(f.read()).hexdigest()
+
+
+class TqdmUpTo(tqdm):
+ """TQDM progress bar when downloading files.
+
+ From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
+
+ """
+
+ def update_to(
+ self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None
+ ) -> None:
+ """Updates the progress bar.
+
+ Args:
+ blocks (int): Number of blocks transferred so far. Defaults to 1.
+ block_size (int): Size of each block, in tqdm units. Defaults to 1.
+ total_size (Optional[int]): Total size in tqdm units. Defaults to None.
+ """
+ if total_size is not None:
+ self.total = total_size # pylint: disable=attribute-defined-outside-init
+ self.update(blocks * block_size - self.n)
+
+
+def download_url(url: str, filename: str) -> None:
+ """Downloads a file from url to filename, with a progress bar."""
+ with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
+ urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec
+
+
+def _download_raw_dataset(metadata: Dict) -> None:
+ """Downloads the dataset described by metadata and verifies its SHA-256."""
+ if os.path.exists(metadata["filename"]):
+ return
+ logger.info(f"Downloading raw dataset from {metadata['url']}...")
+ download_url(metadata["url"], metadata["filename"])
+ logger.info("Computing SHA-256...")
+ sha256 = compute_sha256(metadata["filename"])
+ if sha256 != metadata["sha256"]:
+ raise ValueError(
+ "Downloaded data file SHA-256 does not match that listed in metadata document."
+ )
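+
+# Example metadata document (hedged; the values are placeholders, not a real
+# dataset entry):
+#
+# metadata = {
+# "url": "https://example.com/dataset.zip",
+# "filename": "dataset.zip",
+# "sha256": "<expected hex digest>",
+# }
+# _download_raw_dataset(metadata)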