From 8248f173132dfb7e47ec62b08e9235990c8626e3 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 24 Mar 2021 22:15:54 +0100 Subject: renamed datasets to data, added iam refactor --- text_recognizer/data/__init__.py | 1 + text_recognizer/data/base_data_module.py | 89 +++++++ text_recognizer/data/base_dataset.py | 73 ++++++ text_recognizer/data/download_utils.py | 73 ++++++ text_recognizer/data/emnist.py | 210 +++++++++++++++ text_recognizer/data/emnist_essentials.json | 1 + text_recognizer/data/emnist_lines.py | 280 ++++++++++++++++++++ text_recognizer/data/iam.py | 120 +++++++++ text_recognizer/data/iam_dataset.py | 133 ++++++++++ text_recognizer/data/iam_lines_dataset.py | 110 ++++++++ text_recognizer/data/iam_paragraphs_dataset.py | 291 +++++++++++++++++++++ text_recognizer/data/iam_preprocessor.py | 196 ++++++++++++++ text_recognizer/data/sentence_generator.py | 85 ++++++ text_recognizer/data/transforms.py | 266 +++++++++++++++++++ text_recognizer/data/util.py | 209 +++++++++++++++ text_recognizer/datasets/__init__.py | 1 - text_recognizer/datasets/base_data_module.py | 89 ------- text_recognizer/datasets/base_dataset.py | 73 ------ text_recognizer/datasets/download_utils.py | 73 ------ text_recognizer/datasets/emnist.py | 210 --------------- text_recognizer/datasets/emnist_essentials.json | 1 - text_recognizer/datasets/emnist_lines.py | 280 -------------------- text_recognizer/datasets/iam_dataset.py | 133 ---------- text_recognizer/datasets/iam_lines_dataset.py | 110 -------- text_recognizer/datasets/iam_paragraphs_dataset.py | 291 --------------------- text_recognizer/datasets/iam_preprocessor.py | 196 -------------- text_recognizer/datasets/sentence_generator.py | 85 ------ text_recognizer/datasets/transforms.py | 266 ------------------- text_recognizer/datasets/util.py | 209 --------------- 29 files changed, 2137 insertions(+), 2017 deletions(-) create mode 100644 text_recognizer/data/__init__.py create mode 100644 text_recognizer/data/base_data_module.py create mode 100644 text_recognizer/data/base_dataset.py create mode 100644 text_recognizer/data/download_utils.py create mode 100644 text_recognizer/data/emnist.py create mode 100644 text_recognizer/data/emnist_essentials.json create mode 100644 text_recognizer/data/emnist_lines.py create mode 100644 text_recognizer/data/iam.py create mode 100644 text_recognizer/data/iam_dataset.py create mode 100644 text_recognizer/data/iam_lines_dataset.py create mode 100644 text_recognizer/data/iam_paragraphs_dataset.py create mode 100644 text_recognizer/data/iam_preprocessor.py create mode 100644 text_recognizer/data/sentence_generator.py create mode 100644 text_recognizer/data/transforms.py create mode 100644 text_recognizer/data/util.py delete mode 100644 text_recognizer/datasets/__init__.py delete mode 100644 text_recognizer/datasets/base_data_module.py delete mode 100644 text_recognizer/datasets/base_dataset.py delete mode 100644 text_recognizer/datasets/download_utils.py delete mode 100644 text_recognizer/datasets/emnist.py delete mode 100644 text_recognizer/datasets/emnist_essentials.json delete mode 100644 text_recognizer/datasets/emnist_lines.py delete mode 100644 text_recognizer/datasets/iam_dataset.py delete mode 100644 text_recognizer/datasets/iam_lines_dataset.py delete mode 100644 text_recognizer/datasets/iam_paragraphs_dataset.py delete mode 100644 text_recognizer/datasets/iam_preprocessor.py delete mode 100644 text_recognizer/datasets/sentence_generator.py delete mode 100644 text_recognizer/datasets/transforms.py delete mode 100644 text_recognizer/datasets/util.py (limited to 'text_recognizer') diff --git a/text_recognizer/data/__init__.py b/text_recognizer/data/__init__.py new file mode 100644 index 0000000..2727b20 --- /dev/null +++ b/text_recognizer/data/__init__.py @@ -0,0 +1 @@ +"""Dataset modules.""" diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py new file mode 100644 index 0000000..f5e7300 --- /dev/null +++ b/text_recognizer/data/base_data_module.py @@ -0,0 +1,89 @@ +"""Base lightning DataModule class.""" +from pathlib import Path +from typing import Dict + +import pytorch_lightning as pl +from torch.utils.data import DataLoader + + +def load_and_print_info(data_module_class: type) -> None: + """Load EMNISTLines and prints info.""" + dataset = data_module_class() + dataset.prepare_data() + dataset.setup() + print(dataset) + + +class BaseDataModule(pl.LightningDataModule): + """Base PyTorch Lightning DataModule.""" + + def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None: + super().__init__() + self.batch_size = batch_size + self.num_workers = num_workers + + # Placeholders for subclasses. + self.dims = None + self.output_dims = None + self.mapping = None + + @classmethod + def data_dirname(cls) -> Path: + """Return the path to the base data directory.""" + return Path(__file__).resolve().parents[2] / "data" + + def config(self) -> Dict: + """Return important settings of the dataset.""" + return { + "input_dim": self.dims, + "output_dims": self.output_dims, + "mapping": self.mapping, + } + + def prepare_data(self) -> None: + """Prepare data for training.""" + pass + + def setup(self, stage: str = None) -> None: + """Split into train, val, test, and set dims. + + Should assign `torch Dataset` objects to self.data_train, self.data_val, and + optionally self.data_test. + + Args: + stage (Any): Variable to set splits. + + """ + self.data_train = None + self.data_val = None + self.data_test = None + + def train_dataloader(self) -> DataLoader: + """Retun DataLoader for train data.""" + return DataLoader( + self.data_train, + shuffle=True, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=True, + ) + + def val_dataloader(self) -> DataLoader: + """Return DataLoader for val data.""" + return DataLoader( + self.data_val, + shuffle=False, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=True, + ) + + def test_dataloader(self) -> DataLoader: + """Return DataLoader for val data.""" + return DataLoader( + self.data_test, + shuffle=False, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=True, + ) diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py new file mode 100644 index 0000000..a9e9c24 --- /dev/null +++ b/text_recognizer/data/base_dataset.py @@ -0,0 +1,73 @@ +"""Base PyTorch Dataset class.""" +from typing import Any, Callable, Dict, Sequence, Tuple, Union + +import torch +from torch import Tensor +from torch.utils.data import Dataset + + +class BaseDataset(Dataset): + """ + Base Dataset class that processes data and targets through optional transfroms. + + Args: + data (Union[Sequence, Tensor]): Torch tensors, numpy arrays, or PIL images. + targets (Union[Sequence, Tensor]): Torch tensors or numpy arrays. + tranform (Callable): Function that takes a datum and applies transforms. + target_transform (Callable): Fucntion that takes a target and applies + target transforms. + """ + + def __init__( + self, + data: Union[Sequence, Tensor], + targets: Union[Sequence, Tensor], + transform: Callable = None, + target_transform: Callable = None, + ) -> None: + if len(data) != len(targets): + raise ValueError("Data and targets must be of equal length.") + self.data = data + self.targets = targets + self.transform = transform + self.target_transform = target_transform + + def __len__(self) -> int: + """Return the length of the dataset.""" + return len(self.data) + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + """Return a datum and its target, after processing by transforms. + + Args: + index (int): Index of a datum in the dataset. + + Returns: + Tuple[Any, Any]: Datum and target pair. + + """ + datum, target = self.data[index], self.targets[index] + + if self.transform is not None: + datum = self.transform(datum) + + if self.target_transform is not None: + target = self.target_transform(target) + + return datum, target + + +def convert_strings_to_labels( + strings: Sequence[str], mapping: Dict[str, int], length: int +) -> Tensor: + """ + Convert a sequence of N strings to (N, length) ndarray, with each string wrapped with and tokens, + and padded wiht the

token. + """ + labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["

"] + for i, string in enumerate(strings): + tokens = list(string) + tokens = ["", *tokens, ""] + for j, token in enumerate(tokens): + labels[i, j] = mapping[token] + return labels diff --git a/text_recognizer/data/download_utils.py b/text_recognizer/data/download_utils.py new file mode 100644 index 0000000..e3dc68c --- /dev/null +++ b/text_recognizer/data/download_utils.py @@ -0,0 +1,73 @@ +"""Util functions for downloading datasets.""" +import hashlib +from pathlib import Path +from typing import Dict, List, Optional +from urllib.request import urlretrieve + +from loguru import logger +from tqdm import tqdm + + +def _compute_sha256(filename: Path) -> str: + """Returns the SHA256 checksum of a file.""" + with filename.open(mode="rb") as f: + return hashlib.sha256(f.read()).hexdigest() + + +class TqdmUpTo(tqdm): + """TQDM progress bar when downloading files. + + From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py + + """ + + def update_to( + self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None + ) -> None: + """Updates the progress bar. + + Args: + blocks (int): Number of blocks transferred so far. Defaults to 1. + block_size (int): Size of each block, in tqdm units. Defaults to 1. + total_size (Optional[int]): Total size in tqdm units. Defaults to None. + """ + if total_size is not None: + self.total = total_size # pylint: disable=attribute-defined-outside-init + self.update(blocks * block_size - self.n) + + +def _download_url(url: str, filename: str) -> None: + """Downloads a file from url to filename, with a progress bar.""" + with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: + urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec + + +def download_dataset(metadata: Dict, dl_dir: Path) -> Optional[Path]: + """Downloads dataset using a metadata file. + + Args: + metadata (Dict): A metadata file of the dataset. + dl_dir (Path): Download directory for the dataset. + + Returns: + Optional[Path]: Returns filename if dataset is downloaded, None if it already + exists. + + Raises: + ValueError: If the SHA-256 value is not the same between the dataset and + the metadata file. + + """ + dl_dir.mkdir(parents=True, exist_ok=True) + filename = dl_dir / metadata["filename"] + if filename.exists(): + return + logger.info(f"Downloading raw dataset from {metadata['url']} to {filename}...") + _download_url(metadata["url"], filename) + logger.info("Computing the SHA-256...") + sha256 = _compute_sha256(filename) + if sha256 != metadata["sha256"]: + raise ValueError( + "Downloaded data file SHA-256 does not match that listed in metadata document." + ) + return filename diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py new file mode 100644 index 0000000..7f67893 --- /dev/null +++ b/text_recognizer/data/emnist.py @@ -0,0 +1,210 @@ +"""EMNIST dataset: downloads it from FSDL aws url if not present.""" +from pathlib import Path +from typing import Sequence, Tuple +import json +import os +import shutil +import zipfile + +import h5py +import numpy as np +from loguru import logger +import toml +import torch +from torch.utils.data import random_split +from torchvision import transforms + +from text_recognizer.data.base_dataset import BaseDataset +from text_recognizer.data.base_data_module import ( + BaseDataModule, + load_and_print_info, +) +from text_recognizer.data.download_utils import download_dataset + + +SEED = 4711 +NUM_SPECIAL_TOKENS = 4 +SAMPLE_TO_BALANCE = True + +RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "emnist" +METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" +DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "emnist" +PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist" +PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5" +ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json" + + +class EMNIST(BaseDataModule): + """ + "The EMNIST dataset is a set of handwritten character digits derived from the NIST Special Database 19 + and converted to a 28x28 pixel image format and dataset structure that directly matches the MNIST dataset." + From https://www.nist.gov/itl/iad/image-group/emnist-dataset + + The data split we will use is + EMNIST ByClass: 814,255 characters. 62 unbalanced classes. + """ + + def __init__( + self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8 + ) -> None: + super().__init__(batch_size, num_workers) + if not ESSENTIALS_FILENAME.exists(): + _download_and_process_emnist() + with ESSENTIALS_FILENAME.open() as f: + essentials = json.load(f) + self.train_fraction = train_fraction + self.mapping = list(essentials["characters"]) + self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)} + self.data_train = None + self.data_val = None + self.data_test = None + self.transform = transforms.Compose([transforms.ToTensor()]) + self.dims = (1, *essentials["input_shape"]) + self.output_dims = (1,) + + def prepare_data(self) -> None: + if not PROCESSED_DATA_FILENAME.exists(): + _download_and_process_emnist() + + def setup(self, stage: str = None) -> None: + if stage == "fit" or stage is None: + with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: + self.x_train = f["x_train"][:] + self.y_train = f["y_train"][:].squeeze().astype(int) + + dataset_train = BaseDataset( + self.x_train, self.y_train, transform=self.transform + ) + train_size = int(self.train_fraction * len(dataset_train)) + val_size = len(dataset_train) - train_size + self.data_train, self.data_val = random_split( + dataset_train, [train_size, val_size], generator=torch.Generator() + ) + + if stage == "test" or stage is None: + with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: + self.x_test = f["x_test"][:] + self.y_test = f["y_test"][:].squeeze().astype(int) + self.data_test = BaseDataset( + self.x_test, self.y_test, transform=self.transform + ) + + def __repr__(self) -> str: + basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.dims}\n" + if not any([self.data_train, self.data_val, self.data_test]): + return basic + + datum, target = next(iter(self.train_dataloader())) + data = ( + f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" + f"Batch x stats: {(datum.shape, datum.dtype, datum.min(), datum.mean(), datum.std(), datum.max())}\n" + f"Batch y stats: {(target.shape, target.dtype, target.min(), target.max())}\n" + ) + + return basic + data + + +def _download_and_process_emnist() -> None: + metadata = toml.load(METADATA_FILENAME) + download_dataset(metadata, DL_DATA_DIRNAME) + _process_raw_dataset(metadata["filename"], DL_DATA_DIRNAME) + + +def _process_raw_dataset(filename: str, dirname: Path) -> None: + logger.info("Unzipping EMNIST...") + curdir = os.getcwd() + os.chdir(dirname) + content = zipfile.ZipFile(filename, "r") + content.extract("matlab/emnist-byclass.mat") + + from scipy.io import loadmat + + logger.info("Loading training data from .mat file") + data = loadmat("matlab/emnist-byclass.mat") + x_train = ( + data["dataset"]["train"][0, 0]["images"][0, 0] + .reshape(-1, 28, 28) + .swapaxes(1, 2) + ) + y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS + x_test = ( + data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) + ) + y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS + + if SAMPLE_TO_BALANCE: + logger.info("Balancing classes to reduce amount of data") + x_train, y_train = _sample_to_balance(x_train, y_train) + x_test, y_test = _sample_to_balance(x_test, y_test) + + logger.info("Saving to HDF5 in a compressed format...") + PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + with h5py.File(PROCESSED_DATA_FILENAME, "w") as f: + f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf") + f.create_dataset("y_train", data=y_train, dtype="u1", compression="lzf") + f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf") + f.create_dataset("y_test", data=y_test, dtype="u1", compression="lzf") + + logger.info("Saving essential dataset parameters to text_recognizer/datasets...") + mapping = {int(k): chr(v) for k, v in data["dataset"]["mapping"][0, 0]} + characters = _augment_emnist_characters(mapping.values()) + essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])} + + with ESSENTIALS_FILENAME.open(mode="w") as f: + json.dump(essentials, f) + + logger.info("Cleaning up...") + shutil.rmtree("matlab") + os.chdir(curdir) + + +def _sample_to_balance(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Balances the dataset by taking the mean number of instances per class.""" + np.random.seed(SEED) + num_to_sample = int(np.bincount(y.flatten()).mean()) + all_sampled_indices = [] + for label in np.unique(y.flatten()): + indices = np.where(y == label)[0] + sampled_indices = np.unique(np.random.choice(indices, num_to_sample)) + all_sampled_indices.append(sampled_indices) + indices = np.concatenate(all_sampled_indices) + x_sampled = x[indices] + y_sampled = y[indices] + return x_sampled, y_sampled + + +def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]: + """Augment the mapping with extra symbols.""" + # Extra characters from the IAM dataset. + iam_characters = [ + " ", + "!", + '"', + "#", + "&", + "'", + "(", + ")", + "*", + "+", + ",", + "-", + ".", + "/", + ":", + ";", + "?", + ] + + # Also add special tokens for: + # - CTC blank token at index 0 + # - Start token at index 1 + # - End token at index 2 + # - Padding token at index 3 + # Note: Do not forget to update NUM_SPECIAL_TOKENS if changing this! + return ["", "", "", "

", *characters, *iam_characters] + + +def download_emnist() -> None: + """Download dataset from internet, if it does not exists, and displays info.""" + load_and_print_info(EMNIST) diff --git a/text_recognizer/data/emnist_essentials.json b/text_recognizer/data/emnist_essentials.json new file mode 100644 index 0000000..3f46a73 --- /dev/null +++ b/text_recognizer/data/emnist_essentials.json @@ -0,0 +1 @@ +{"characters": ["", "", "", "

", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]} diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py new file mode 100644 index 0000000..6c14add --- /dev/null +++ b/text_recognizer/data/emnist_lines.py @@ -0,0 +1,280 @@ +"""Dataset of generated text from EMNIST characters.""" +from collections import defaultdict +from pathlib import Path +from typing import Callable, Dict, Tuple, Sequence + +import h5py +from loguru import logger +import numpy as np +from PIL import Image +import torch +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode + +from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels +from text_recognizer.data.base_data_module import ( + BaseDataModule, + load_and_print_info, +) +from text_recognizer.data.emnist import EMNIST +from text_recognizer.data.sentence_generator import SentenceGenerator + + +DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist_lines" +ESSENTIALS_FILENAME = ( + Path(__file__).parents[0].resolve() / "emnist_lines_essentials.json" +) + +SEED = 4711 +IMAGE_HEIGHT = 56 +IMAGE_WIDTH = 1024 +IMAGE_X_PADDING = 28 +MAX_OUTPUT_LENGTH = 89 # Same as IAMLines + + +class EMNISTLines(BaseDataModule): + """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST,""" + + def __init__( + self, + augment: bool = True, + batch_size: int = 128, + num_workers: int = 0, + max_length: int = 32, + min_overlap: float = 0.0, + max_overlap: float = 0.33, + num_train: int = 10_000, + num_val: int = 2_000, + num_test: int = 2_000, + ) -> None: + super().__init__(batch_size, num_workers) + + self.augment = augment + self.max_length = max_length + self.min_overlap = min_overlap + self.max_overlap = max_overlap + self.num_train = num_train + self.num_val = num_val + self.num_test = num_test + + self.emnist = EMNIST() + self.mapping = self.emnist.mapping + max_width = ( + int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap)) + + IMAGE_X_PADDING + ) + + if max_width >= IMAGE_WIDTH: + raise ValueError( + f"max_width {max_width} greater than IMAGE_WIDTH {IMAGE_WIDTH}" + ) + + self.dims = ( + self.emnist.dims[0], + IMAGE_HEIGHT, + IMAGE_WIDTH + ) + + if self.max_length >= MAX_OUTPUT_LENGTH: + raise ValueError("max_length greater than MAX_OUTPUT_LENGTH") + + self.output_dims = (MAX_OUTPUT_LENGTH, 1) + self.data_train = None + self.data_val = None + self.data_test = None + + @property + def data_filename(self) -> Path: + """Return name of dataset.""" + return ( + DATA_DIRNAME / (f"ml_{self.max_length}_" + f"o{self.min_overlap:f}_{self.max_overlap:f}_" + f"ntr{self.num_train}_" + f"ntv{self.num_val}_" + f"nte{self.num_test}.h5") + ) + + def prepare_data(self) -> None: + if self.data_filename.exists(): + return + np.random.seed(SEED) + self._generate_data("train") + self._generate_data("val") + self._generate_data("test") + + def setup(self, stage: str = None) -> None: + logger.info("EMNISTLinesDataset loading data from HDF5...") + if stage == "fit" or stage is None: + print(self.data_filename) + with h5py.File(self.data_filename, "r") as f: + x_train = f["x_train"][:] + y_train = torch.LongTensor(f["y_train"][:]) + x_val = f["x_val"][:] + y_val = torch.LongTensor(f["y_val"][:]) + + self.data_train = BaseDataset( + x_train, y_train, transform=_get_transform(augment=self.augment) + ) + self.data_val = BaseDataset( + x_val, y_val, transform=_get_transform(augment=self.augment) + ) + + if stage == "test" or stage is None: + with h5py.File(self.data_filename, "r") as f: + x_test = f["x_test"][:] + y_test = torch.LongTensor(f["y_test"][:]) + + self.data_test = BaseDataset( + x_test, y_test, transform=_get_transform(augment=False) + ) + + def __repr__(self) -> str: + """Return str about dataset.""" + basic = ( + "EMNISTLines2 Dataset\n" # pylint: disable=no-member + f"Min overlap: {self.min_overlap}\n" + f"Max overlap: {self.max_overlap}\n" + f"Num classes: {len(self.mapping)}\n" + f"Dims: {self.dims}\n" + f"Output dims: {self.output_dims}\n" + ) + + if not any([self.data_train, self.data_val, self.data_test]): + return basic + + x, y = next(iter(self.train_dataloader())) + data = ( + f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" + f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" + f"Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" + ) + return basic + data + + def _generate_data(self, split: str) -> None: + logger.info(f"EMNISTLines generating data for {split}...") + sentence_generator = SentenceGenerator( + self.max_length - 2 + ) # Subtract by 2 because start/end token + + emnist = self.emnist + emnist.prepare_data() + emnist.setup() + + if split == "train": + samples_by_char = _get_samples_by_char( + emnist.x_train, emnist.y_train, emnist.mapping + ) + num = self.num_train + elif split == "val": + samples_by_char = _get_samples_by_char( + emnist.x_train, emnist.y_train, emnist.mapping + ) + num = self.num_val + else: + samples_by_char = _get_samples_by_char( + emnist.x_test, emnist.y_test, emnist.mapping + ) + num = self.num_test + + DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + with h5py.File(self.data_filename, "a") as f: + x, y = _create_dataset_of_images( + num, + samples_by_char, + sentence_generator, + self.min_overlap, + self.max_overlap, + self.dims, + ) + y = convert_strings_to_labels( + y, emnist.inverse_mapping, length=MAX_OUTPUT_LENGTH + ) + f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf") + f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf") + + +def _get_samples_by_char( + samples: np.ndarray, labels: np.ndarray, mapping: Dict +) -> defaultdict: + samples_by_char = defaultdict(list) + for sample, label in zip(samples, labels): + samples_by_char[mapping[label]].append(sample) + return samples_by_char + + +def _select_letter_samples_for_string(string: str, samples_by_char: defaultdict): + null_image = torch.zeros((28, 28), dtype=torch.uint8) + sample_image_by_char = {} + for char in string: + if char in sample_image_by_char: + continue + samples = samples_by_char[char] + sample = samples[np.random.choice(len(samples))] if samples else null_image + sample_image_by_char[char] = sample.reshape(28, 28) + return [sample_image_by_char[char] for char in string] + + +def _construct_image_from_string( + string: str, + samples_by_char: defaultdict, + min_overlap: float, + max_overlap: float, + width: int, +) -> torch.Tensor: + overlap = np.random.uniform(min_overlap, max_overlap) + sampled_images = _select_letter_samples_for_string(string, samples_by_char) + N = len(sampled_images) + H, W = sampled_images[0].shape + next_overlap_width = W - int(overlap * W) + concatenated_image = torch.zeros((H, width), dtype=torch.uint8) + x = IMAGE_X_PADDING + for image in sampled_images: + concatenated_image[:, x : (x + W)] += image + x += next_overlap_width + return torch.minimum(torch.Tensor([255]), concatenated_image) + + +def _create_dataset_of_images( + num_samples: int, + samples_by_char: defaultdict, + sentence_generator: SentenceGenerator, + min_overlap: float, + max_overlap: float, + dims: Tuple, +) -> Tuple[torch.Tensor, torch.Tensor]: + images = torch.zeros((num_samples, IMAGE_HEIGHT, dims[2])) + labels = [] + for n in range(num_samples): + label = sentence_generator.generate() + crop = _construct_image_from_string( + label, samples_by_char, min_overlap, max_overlap, dims[-1] + ) + height = crop.shape[0] + y = (IMAGE_HEIGHT - height) // 2 + images[n, y : (y + height), :] = crop + labels.append(label) + return images, labels + + +def _get_transform(augment: bool = False) -> Callable: + if not augment: + return transforms.Compose([transforms.ToTensor()]) + return transforms.Compose( + [ + transforms.ToTensor(), + transforms.ColorJitter(brightness=(0.5, 1.0)), + transforms.RandomAffine( + degrees=3, + translate=(0.0, 0.05), + scale=(0.4, 1.1), + shear=(-40, 50), + interpolation=InterpolationMode.BILINEAR, + fill=0, + ), + ] + ) + + +def generate_emnist_lines() -> None: + """Generates a synthetic handwritten dataset and displays info,""" + load_and_print_info(EMNISTLines) diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py new file mode 100644 index 0000000..fcfe9a7 --- /dev/null +++ b/text_recognizer/data/iam.py @@ -0,0 +1,120 @@ +"""Class for loading the IAM dataset, which encompasses both paragraphs and lines, with associated utilities.""" +import os +from pathlib import Path +from typing import Any, Dict, List +import xml.etree.ElementTree as ElementTree +import zipfile + +from boltons.cacheutils import cachedproperty +from loguru import logger +from PIL import Image +import toml + +from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info +from text_recognizer.data.download_utils import download_dataset + + +RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "iam" +METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" +DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "iam" +EXTRACTED_DATASET_DIRNAME = DL_DATA_DIRNAME / "iamdb" + +DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be. +LINE_REGION_PADDING = 16 # Add this many pixels around the exact coordinates. + + +class IAM(BaseDataModule): + """ + "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text, + which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels. + From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database + The data split we will use is + IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text lines. + The validation set has been merged into the train set. + The train set has 7,101 lines from 326 writers. + The test set has 1,861 lines from 128 writers. + The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only. + """ + + def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None: + super().__init__(batch_size, num_workers) + self.metadata = toml.load(METADATA_FILENAME) + + def prepare_data(self) -> None: + if self.xml_filenames: + return + filename = download_dataset(self.metadata, DL_DATA_DIRNAME) + _extract_raw_dataset(filename, DL_DATA_DIRNAME) + + @property + def xml_filenames(self) -> List[Path]: + return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml")) + + @property + def form_filenames(self) -> List[Path]: + return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg")) + + @property + def form_filenames_by_id(self) -> Dict[str, Path]: + return {filename.stem: filename for filename in self.form_filenames} + + @property + def split_by_id(self) -> Dict[str, str]: + return {filename.stem: "test" if filename.stem in self.metadata["test_ids"] else "trainval" for filename in self.form_filenames} + + @cachedproperty + def line_strings_by_id(self) -> Dict[str, List[str]]: + """Return a dict from name of IAM form to list of line texts in it.""" + return {filename.stem: _get_line_strings_from_xml_file(filename) for filename in self.xml_filenames} + + @cachedproperty + def line_regions_by_id(self) -> Dict[str, List[Dict[str, int]]]: + """Return a dict from name IAM form to list of (x1, x2, y1, y2) coordinates of all lines in it.""" + return {filename.stem: _get_line_regions_from_xml_file(filename) for filename in self.xml_filenames} + + def __repr__(self) -> str: + """Return info about the dataset.""" + return ("IAM Dataset\n" + f"Num forms total: {len(self.xml_filenames)}\n" + f"Num in test set: {len(self.metadata['test_ids'])}\n") + + +def _extract_raw_dataset(filename: Path, dirname: Path) -> None: + logger.info("Extracting IAM data...") + curdir = os.getcwd() + os.chdir(dirname) + with zipfile.ZipFile(filename, "r") as f: + f.extractall() + os.chdir(curdir) + + +def _get_line_strings_from_xml_file(filename: str) -> List[str]: + """Get the text content of each line. Note that we replace ": with ".""" + xml_root_element = ElementTree.parse(filename).getroot() # nosec + xml_line_elements = xml_root_element.findall("handwritten-part/line") + return [el.attrib["text"].replace(""", '"') for el in xml_line_elements] + + +def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]: + """Get line region dict for each line.""" + xml_root_element = ElementTree.parse(filename).getroot() # nosec + xml_line_elements = xml_root_element.findall("handwritten-part/line") + return [_get_line_region_from_xml_file(el) for el in xml_line_elements] + + +def _get_line_region_from_xml_file(xml_line: Any) -> Dict[str, int]: + word_elements = xml_line.findall("word/cmp") + x1s = [int(el.attrib["x"]) for el in word_elements] + y1s = [int(el.attrib["y"]) for el in word_elements] + x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements] + y2s = [int(el.attrib["x"]) + int(el.attrib["height"]) for el in word_elements] + return { + "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, + "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, + "x2": min(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, + "y2": min(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, + } + + +def download_iam() -> None: + load_and_print_info(IAM) diff --git a/text_recognizer/data/iam_dataset.py b/text_recognizer/data/iam_dataset.py new file mode 100644 index 0000000..a8998b9 --- /dev/null +++ b/text_recognizer/data/iam_dataset.py @@ -0,0 +1,133 @@ +"""Class for loading the IAM dataset, which encompasses both paragraphs and lines, with associated utilities.""" +import os +from typing import Any, Dict, List +import zipfile + +from boltons.cacheutils import cachedproperty +import defusedxml.ElementTree as ET +from loguru import logger +import toml + +from text_recognizer.datasets.util import _download_raw_dataset, DATA_DIRNAME + +RAW_DATA_DIRNAME = DATA_DIRNAME / "raw" / "iam" +METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" +EXTRACTED_DATASET_DIRNAME = RAW_DATA_DIRNAME / "iamdb" +RAW_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + +DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be. +LINE_REGION_PADDING = 0 # Add this many pixels around the exact coordinates. + + +class IamDataset: + """IAM dataset. + + "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text, + which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels." + From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database + + The data split we will use is + IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text lines. + The validation set has been merged into the train set. + The train set has 7,101 lines from 326 writers. + The test set has 1,861 lines from 128 writers. + The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only. + + """ + + def __init__(self) -> None: + self.metadata = toml.load(METADATA_FILENAME) + + def load_or_generate_data(self) -> None: + """Downloads IAM dataset if xml files does not exist.""" + if not self.xml_filenames: + self._download_iam() + + @property + def xml_filenames(self) -> List: + """List of xml filenames.""" + return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml")) + + @property + def form_filenames(self) -> List: + """List of forms filenames.""" + return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg")) + + def _download_iam(self) -> None: + curdir = os.getcwd() + os.chdir(RAW_DATA_DIRNAME) + _download_raw_dataset(self.metadata) + _extract_raw_dataset(self.metadata) + os.chdir(curdir) + + @property + def form_filenames_by_id(self) -> Dict: + """Creates a dictionary with filenames as keys and forms as values.""" + return {filename.stem: filename for filename in self.form_filenames} + + @cachedproperty + def line_strings_by_id(self) -> Dict: + """Return a dict from name of IAM form to a list of line texts in it.""" + return { + filename.stem: _get_line_strings_from_xml_file(filename) + for filename in self.xml_filenames + } + + @cachedproperty + def line_regions_by_id(self) -> Dict: + """Return a dict from name of IAM form to a list of (x1, x2, y1, y2) coordinates of all lines in it.""" + return { + filename.stem: _get_line_regions_from_xml_file(filename) + for filename in self.xml_filenames + } + + def __repr__(self) -> str: + """Print info about dataset.""" + return "IAM Dataset\n" f"Number of forms: {len(self.xml_filenames)}\n" + + +def _extract_raw_dataset(metadata: Dict) -> None: + logger.info("Extracting IAM data.") + with zipfile.ZipFile(metadata["filename"], "r") as zip_file: + zip_file.extractall() + + +def _get_line_strings_from_xml_file(filename: str) -> List[str]: + """Get the text content of each line. Note that we replace " with ".""" + xml_root_element = ET.parse(filename).getroot() # nosec + xml_line_elements = xml_root_element.findall("handwritten-part/line") + return [el.attrib["text"].replace(""", '"') for el in xml_line_elements] + + +def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]: + """Get the line region dict for each line.""" + xml_root_element = ET.parse(filename).getroot() # nosec + xml_line_elements = xml_root_element.findall("handwritten-part/line") + return [_get_line_region_from_xml_element(el) for el in xml_line_elements] + + +def _get_line_region_from_xml_element(xml_line: Any) -> Dict[str, int]: + """Extracts coordinates for each line of text.""" + # TODO: fix input! + word_elements = xml_line.findall("word/cmp") + x1s = [int(el.attrib["x"]) for el in word_elements] + y1s = [int(el.attrib["y"]) for el in word_elements] + x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements] + y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements] + return { + "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, + "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, + "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, + "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, + } + + +def main() -> None: + """Initializes the dataset and print info about the dataset.""" + dataset = IamDataset() + dataset.load_or_generate_data() + print(dataset) + + +if __name__ == "__main__": + main() diff --git a/text_recognizer/data/iam_lines_dataset.py b/text_recognizer/data/iam_lines_dataset.py new file mode 100644 index 0000000..1cb84bd --- /dev/null +++ b/text_recognizer/data/iam_lines_dataset.py @@ -0,0 +1,110 @@ +"""IamLinesDataset class.""" +from typing import Callable, Dict, List, Optional, Tuple, Union + +import h5py +from loguru import logger +import torch +from torch import Tensor +from torchvision.transforms import ToTensor + +from text_recognizer.datasets.dataset import Dataset +from text_recognizer.datasets.util import ( + compute_sha256, + DATA_DIRNAME, + download_url, + EmnistMapper, +) + + +PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines" +PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "iam_lines.h5" +PROCESSED_DATA_URL = ( + "https://s3-us-west-2.amazonaws.com/fsdl-public-assets/iam_lines.h5" +) + + +class IamLinesDataset(Dataset): + """IAM lines datasets for handwritten text lines.""" + + def __init__( + self, + train: bool = False, + subsample_fraction: float = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + init_token: Optional[str] = None, + pad_token: Optional[str] = None, + eos_token: Optional[str] = None, + lower: bool = False, + ) -> None: + self.pad_token = "_" if pad_token is None else pad_token + + super().__init__( + train=train, + subsample_fraction=subsample_fraction, + transform=transform, + target_transform=target_transform, + init_token=init_token, + pad_token=pad_token, + eos_token=eos_token, + lower=lower, + ) + + @property + def input_shape(self) -> Tuple: + """Input shape of the data.""" + return self.data.shape[1:] if self.data is not None else None + + @property + def output_shape(self) -> Tuple: + """Output shape of the data.""" + return ( + self.targets.shape[1:] + (self.num_classes,) + if self.targets is not None + else None + ) + + def load_or_generate_data(self) -> None: + """Load or generate dataset data.""" + if not PROCESSED_DATA_FILENAME.exists(): + PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + logger.info("Downloading IAM lines...") + download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME) + with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: + self._data = f[f"x_{self.split}"][:] + self._targets = f[f"y_{self.split}"][:] + self._subsample() + + def __repr__(self) -> str: + """Print info about the dataset.""" + return ( + "IAM Lines Dataset\n" # pylint: disable=no-member + f"Number classes: {self.num_classes}\n" + f"Mapping: {self.mapper.mapping}\n" + f"Data: {self.data.shape}\n" + f"Targets: {self.targets.shape}\n" + ) + + def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]: + """Fetches data, target pair of the dataset for a given and index or indices. + + Args: + index (Union[int, Tensor]): Either a list or int of indices/index. + + Returns: + Tuple[Tensor, Tensor]: Data target pair. + + """ + if torch.is_tensor(index): + index = index.tolist() + + data = self.data[index] + targets = self.targets[index] + + if self.transform: + data = self.transform(data) + + if self.target_transform: + targets = self.target_transform(targets) + + return data, targets diff --git a/text_recognizer/data/iam_paragraphs_dataset.py b/text_recognizer/data/iam_paragraphs_dataset.py new file mode 100644 index 0000000..8ba5142 --- /dev/null +++ b/text_recognizer/data/iam_paragraphs_dataset.py @@ -0,0 +1,291 @@ +"""IamParagraphsDataset class and functions for data processing.""" +import random +from typing import Callable, Dict, List, Optional, Tuple, Union + +import click +import cv2 +import h5py +from loguru import logger +import numpy as np +import torch +from torch import Tensor +from torchvision.transforms import ToTensor + +from text_recognizer import util +from text_recognizer.datasets.dataset import Dataset +from text_recognizer.datasets.iam_dataset import IamDataset +from text_recognizer.datasets.util import ( + compute_sha256, + DATA_DIRNAME, + download_url, + EmnistMapper, +) + +INTERIM_DATA_DIRNAME = DATA_DIRNAME / "interim" / "iam_paragraphs" +DEBUG_CROPS_DIRNAME = INTERIM_DATA_DIRNAME / "debug_crops" +PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_paragraphs" +CROPS_DIRNAME = PROCESSED_DATA_DIRNAME / "crops" +GT_DIRNAME = PROCESSED_DATA_DIRNAME / "gt" + +PARAGRAPH_BUFFER = 50 # Pixels in the IAM form images to leave around the lines. +TEST_FRACTION = 0.2 +SEED = 4711 + + +class IamParagraphsDataset(Dataset): + """IAM Paragraphs dataset for paragraphs of handwritten text.""" + + def __init__( + self, + train: bool = False, + subsample_fraction: float = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: + super().__init__( + train=train, + subsample_fraction=subsample_fraction, + transform=transform, + target_transform=target_transform, + ) + # Load Iam dataset. + self.iam_dataset = IamDataset() + + self.num_classes = 3 + self._input_shape = (256, 256) + self._output_shape = self._input_shape + (self.num_classes,) + self._ids = None + + def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]: + """Fetches data, target pair of the dataset for a given and index or indices. + + Args: + index (Union[int, Tensor]): Either a list or int of indices/index. + + Returns: + Tuple[Tensor, Tensor]: Data target pair. + + """ + if torch.is_tensor(index): + index = index.tolist() + + data = self.data[index] + targets = self.targets[index] + + seed = np.random.randint(SEED) + random.seed(seed) # apply this seed to target tranfsorms + torch.manual_seed(seed) # needed for torchvision 0.7 + if self.transform: + data = self.transform(data) + + random.seed(seed) # apply this seed to target tranfsorms + torch.manual_seed(seed) # needed for torchvision 0.7 + if self.target_transform: + targets = self.target_transform(targets) + + return data, targets.long() + + @property + def ids(self) -> Tensor: + """Ids of the dataset.""" + return self._ids + + def get_data_and_target_from_id(self, id_: str) -> Tuple[Tensor, Tensor]: + """Get data target pair from id.""" + ind = self.ids.index(id_) + return self.data[ind], self.targets[ind] + + def load_or_generate_data(self) -> None: + """Load or generate dataset data.""" + num_actual = len(list(CROPS_DIRNAME.glob("*.jpg"))) + num_targets = len(self.iam_dataset.line_regions_by_id) + + if num_actual < num_targets - 2: + self._process_iam_paragraphs() + + self._data, self._targets, self._ids = _load_iam_paragraphs() + self._get_random_split() + self._subsample() + + def _get_random_split(self) -> None: + np.random.seed(SEED) + num_train = int((1 - TEST_FRACTION) * self.data.shape[0]) + indices = np.random.permutation(self.data.shape[0]) + train_indices, test_indices = indices[:num_train], indices[num_train:] + if self.train: + self._data = self.data[train_indices] + self._targets = self.targets[train_indices] + else: + self._data = self.data[test_indices] + self._targets = self.targets[test_indices] + + def _process_iam_paragraphs(self) -> None: + """Crop the part with the text. + + For each page, crop out the part of it that correspond to the paragraph of text, and make sure all crops are + self.input_shape. The ground truth data is the same size, with a one-hot vector at each pixel + corresponding to labels 0=background, 1=odd-numbered line, 2=even-numbered line + """ + crop_dims = self._decide_on_crop_dims() + CROPS_DIRNAME.mkdir(parents=True, exist_ok=True) + DEBUG_CROPS_DIRNAME.mkdir(parents=True, exist_ok=True) + GT_DIRNAME.mkdir(parents=True, exist_ok=True) + logger.info( + f"Cropping paragraphs, generating ground truth, and saving debugging images to {DEBUG_CROPS_DIRNAME}" + ) + for filename in self.iam_dataset.form_filenames: + id_ = filename.stem + line_region = self.iam_dataset.line_regions_by_id[id_] + _crop_paragraph_image(filename, line_region, crop_dims, self.input_shape) + + def _decide_on_crop_dims(self) -> Tuple[int, int]: + """Decide on the dimensions to crop out of the form image. + + Since image width is larger than a comfortable crop around the longest paragraph, + we will make the crop a square form factor. + And since the found dimensions 610x610 are pretty close to 512x512, + we might as well resize crops and make it exactly that, which lets us + do all kinds of power-of-2 pooling and upsampling should we choose to. + + Returns: + Tuple[int, int]: A tuple of crop dimensions. + + Raises: + RuntimeError: When max crop height is larger than max crop width. + + """ + + sample_form_filename = self.iam_dataset.form_filenames[0] + sample_image = util.read_image(sample_form_filename, grayscale=True) + max_crop_width = sample_image.shape[1] + max_crop_height = _get_max_paragraph_crop_height( + self.iam_dataset.line_regions_by_id + ) + if not max_crop_height <= max_crop_width: + raise RuntimeError( + f"Max crop height is larger then max crop width: {max_crop_height} >= {max_crop_width}" + ) + + crop_dims = (max_crop_width, max_crop_width) + logger.info( + f"Max crop width and height were found to be {max_crop_width}x{max_crop_height}." + ) + logger.info(f"Setting them to {max_crop_width}x{max_crop_width}") + return crop_dims + + def __repr__(self) -> str: + """Return info about the dataset.""" + return ( + "IAM Paragraph Dataset\n" # pylint: disable=no-member + f"Num classes: {self.num_classes}\n" + f"Data: {self.data.shape}\n" + f"Targets: {self.targets.shape}\n" + ) + + +def _get_max_paragraph_crop_height(line_regions_by_id: Dict) -> int: + heights = [] + for regions in line_regions_by_id.values(): + min_y1 = min(region["y1"] for region in regions) - PARAGRAPH_BUFFER + max_y2 = max(region["y2"] for region in regions) + PARAGRAPH_BUFFER + height = max_y2 - min_y1 + heights.append(height) + return max(heights) + + +def _crop_paragraph_image( + filename: str, line_regions: Dict, crop_dims: Tuple[int, int], final_dims: Tuple +) -> None: + image = util.read_image(filename, grayscale=True) + + min_y1 = min(region["y1"] for region in line_regions) - PARAGRAPH_BUFFER + max_y2 = max(region["y2"] for region in line_regions) + PARAGRAPH_BUFFER + height = max_y2 - min_y1 + crop_height = crop_dims[0] + buffer = (crop_height - height) // 2 + + # Generate image crop. + image_crop = 255 * np.ones(crop_dims, dtype=np.uint8) + try: + image_crop[buffer : buffer + height] = image[min_y1:max_y2] + except Exception as e: # pylint: disable=broad-except + logger.error(f"Rescued {filename}: {e}") + return + + # Generate ground truth. + gt_image = np.zeros_like(image_crop, dtype=np.uint8) + for index, region in enumerate(line_regions): + gt_image[ + (region["y1"] - min_y1 + buffer) : (region["y2"] - min_y1 + buffer), + region["x1"] : region["x2"], + ] = (index % 2 + 1) + + # Generate image for debugging. + import matplotlib.pyplot as plt + + cmap = plt.get_cmap("Set1") + image_crop_for_debug = np.dstack([image_crop, image_crop, image_crop]) + for index, region in enumerate(line_regions): + color = [255 * _ for _ in cmap(index)[:-1]] + cv2.rectangle( + image_crop_for_debug, + (region["x1"], region["y1"] - min_y1 + buffer), + (region["x2"], region["y2"] - min_y1 + buffer), + color, + 3, + ) + image_crop_for_debug = cv2.resize( + image_crop_for_debug, final_dims, interpolation=cv2.INTER_AREA + ) + util.write_image(image_crop_for_debug, DEBUG_CROPS_DIRNAME / f"{filename.stem}.jpg") + + image_crop = cv2.resize(image_crop, final_dims, interpolation=cv2.INTER_AREA) + util.write_image(image_crop, CROPS_DIRNAME / f"{filename.stem}.jpg") + + gt_image = cv2.resize(gt_image, final_dims, interpolation=cv2.INTER_NEAREST) + util.write_image(gt_image, GT_DIRNAME / f"{filename.stem}.png") + + +def _load_iam_paragraphs() -> None: + logger.info("Loading IAM paragraph crops and ground truth from image files...") + images = [] + gt_images = [] + ids = [] + for filename in CROPS_DIRNAME.glob("*.jpg"): + id_ = filename.stem + image = util.read_image(filename, grayscale=True) + image = 1.0 - image / 255 + + gt_filename = GT_DIRNAME / f"{id_}.png" + gt_image = util.read_image(gt_filename, grayscale=True) + + images.append(image) + gt_images.append(gt_image) + ids.append(id_) + images = np.array(images).astype(np.float32) + gt_images = np.array(gt_images).astype(np.uint8) + ids = np.array(ids) + return images, gt_images, ids + + +@click.command() +@click.option( + "--subsample_fraction", + type=float, + default=None, + help="The subsampling factor of the dataset.", +) +def main(subsample_fraction: float) -> None: + """Load dataset and print info.""" + logger.info("Creating train set...") + dataset = IamParagraphsDataset(train=True, subsample_fraction=subsample_fraction) + dataset.load_or_generate_data() + print(dataset) + logger.info("Creating test set...") + dataset = IamParagraphsDataset(subsample_fraction=subsample_fraction) + dataset.load_or_generate_data() + print(dataset) + + +if __name__ == "__main__": + main() diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py new file mode 100644 index 0000000..a93eb00 --- /dev/null +++ b/text_recognizer/data/iam_preprocessor.py @@ -0,0 +1,196 @@ +"""Preprocessor for extracting word letters from the IAM dataset. + +The code is mostly stolen from: + + https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py + +""" + +import collections +import itertools +from pathlib import Path +import re +from typing import List, Optional, Union + +import click +from loguru import logger +import torch + + +def load_metadata( + data_dir: Path, wordsep: str, use_words: bool = False +) -> collections.defaultdict: + """Loads IAM metadata and returns it as a dictionary.""" + forms = collections.defaultdict(list) + filename = "words.txt" if use_words else "lines.txt" + + with open(data_dir / "ascii" / filename, "r") as f: + lines = (line.strip().split() for line in f if line[0] != "#") + for line in lines: + # Skip word segmentation errors. + if use_words and line[1] == "err": + continue + text = " ".join(line[8:]) + + # Remove garbage tokens: + text = text.replace("#", "") + + # Swap word sep form | to wordsep + text = re.sub(r"\|+|\s", wordsep, text).strip(wordsep) + form_key = "-".join(line[0].split("-")[:2]) + line_key = "-".join(line[0].split("-")[:3]) + box_idx = 4 - use_words + box = tuple(int(val) for val in line[box_idx : box_idx + 4]) + forms[form_key].append({"key": line_key, "box": box, "text": text}) + return forms + + +class Preprocessor: + """A preprocessor for the IAM dataset.""" + + # TODO: add lower case only to when generating... + + def __init__( + self, + data_dir: Union[str, Path], + num_features: int, + tokens_path: Optional[Union[str, Path]] = None, + lexicon_path: Optional[Union[str, Path]] = None, + use_words: bool = False, + prepend_wordsep: bool = False, + ) -> None: + self.wordsep = "▁" + self._use_word = use_words + self._prepend_wordsep = prepend_wordsep + + self.data_dir = Path(data_dir) + + self.forms = load_metadata(self.data_dir, self.wordsep, use_words=use_words) + + # Load the set of graphemes: + graphemes = set() + for _, form in self.forms.items(): + for line in form: + graphemes.update(line["text"].lower()) + self.graphemes = sorted(graphemes) + + # Build the token-to-index and index-to-token maps. + if tokens_path is not None: + with open(tokens_path, "r") as f: + self.tokens = [line.strip() for line in f] + else: + self.tokens = self.graphemes + + if lexicon_path is not None: + with open(lexicon_path, "r") as f: + lexicon = (line.strip().split() for line in f) + lexicon = {line[0]: line[1:] for line in lexicon} + self.lexicon = lexicon + else: + self.lexicon = None + + self.graphemes_to_index = {t: i for i, t in enumerate(self.graphemes)} + self.tokens_to_index = {t: i for i, t in enumerate(self.tokens)} + self.num_features = num_features + self.text = [] + + @property + def num_tokens(self) -> int: + """Returns the number or tokens.""" + return len(self.tokens) + + @property + def use_words(self) -> bool: + """If words are used.""" + return self._use_word + + def extract_train_text(self) -> None: + """Extracts training text.""" + keys = [] + with open(self.data_dir / "task" / "trainset.txt") as f: + keys.extend((line.strip() for line in f)) + + for _, examples in self.forms.items(): + for example in examples: + if example["key"] not in keys: + continue + self.text.append(example["text"].lower()) + + def to_index(self, line: str) -> torch.LongTensor: + """Converts text to a tensor of indices.""" + token_to_index = self.graphemes_to_index + if self.lexicon is not None: + if len(line) > 0: + # If the word is not found in the lexicon, fall back to letters. + line = [ + t + for w in line.split(self.wordsep) + for t in self.lexicon.get(w, self.wordsep + w) + ] + token_to_index = self.tokens_to_index + if self._prepend_wordsep: + line = itertools.chain([self.wordsep], line) + return torch.LongTensor([token_to_index[t] for t in line]) + + def to_text(self, indices: List[int]) -> str: + """Converts indices to text.""" + # Roughly the inverse of `to_index` + encoding = self.graphemes + if self.lexicon is not None: + encoding = self.tokens + return self._post_process(encoding[i] for i in indices) + + def tokens_to_text(self, indices: List[int]) -> str: + """Converts tokens to text.""" + return self._post_process(self.tokens[i] for i in indices) + + def _post_process(self, indices: List[int]) -> str: + """A list join.""" + return "".join(indices).strip(self.wordsep) + + +@click.command() +@click.option("--data_dir", type=str, default=None, help="Path to iam dataset") +@click.option( + "--use_words", is_flag=True, help="Load word segmented dataset instead of lines" +) +@click.option( + "--save_text", type=str, default=None, help="Path to save parsed train text" +) +@click.option("--save_tokens", type=str, default=None, help="Path to save tokens") +def cli( + data_dir: Optional[str], + use_words: bool, + save_text: Optional[str], + save_tokens: Optional[str], +) -> None: + """CLI for extracting text data from the iam dataset.""" + if data_dir is None: + data_dir = ( + Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb" + ) + logger.debug(f"Using data dir: {data_dir}") + if not data_dir.exists(): + raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") + else: + data_dir = Path(data_dir) + + preprocessor = Preprocessor(data_dir, 64, use_words=use_words) + preprocessor.extract_train_text() + + processed_dir = data_dir.parents[2] / "processed" / "iam_lines" + logger.debug(f"Saving processed files at: {processed_dir}") + + if save_text is not None: + logger.info("Saving training text") + with open(processed_dir / save_text, "w") as f: + f.write("\n".join(t for t in preprocessor.text)) + + if save_tokens is not None: + logger.info("Saving tokens") + with open(processed_dir / save_tokens, "w") as f: + f.write("\n".join(preprocessor.tokens)) + + +if __name__ == "__main__": + cli() diff --git a/text_recognizer/data/sentence_generator.py b/text_recognizer/data/sentence_generator.py new file mode 100644 index 0000000..53b781c --- /dev/null +++ b/text_recognizer/data/sentence_generator.py @@ -0,0 +1,85 @@ +"""Downloading the Brown corpus with NLTK for sentence generating.""" + +import itertools +import re +import string +from typing import Optional + +import nltk +from nltk.corpus.reader.util import ConcatenatedCorpusView +import numpy as np + +from text_recognizer.datasets.util import DATA_DIRNAME + +NLTK_DATA_DIRNAME = DATA_DIRNAME / "downloaded" / "nltk" + + +class SentenceGenerator: + """Generates text sentences using the Brown corpus.""" + + def __init__(self, max_length: Optional[int] = None) -> None: + """Loads the corpus and sets word start indices.""" + self.corpus = brown_corpus() + self.word_start_indices = [0] + [ + _.start(0) + 1 for _ in re.finditer(" ", self.corpus) + ] + self.max_length = max_length + + def generate(self, max_length: Optional[int] = None) -> str: + """Generates a word or sentences from the Brown corpus. + + Sample a string from the Brown corpus of length at least one word and at most max_length, padding to + max_length with the '_' characters if sentence is shorter. + + Args: + max_length (Optional[int]): The maximum number of characters in the sentence. Defaults to None. + + Returns: + str: A sentence from the Brown corpus. + + Raises: + ValueError: If max_length was not specified at initialization and not given as an argument. + + """ + if max_length is None: + max_length = self.max_length + if max_length is None: + raise ValueError( + "Must provide max_length to this method or when making this object." + ) + + for _ in range(10): + try: + index = np.random.randint(0, len(self.word_start_indices) - 1) + start_index = self.word_start_indices[index] + end_index_candidates = [] + for index in range(index + 1, len(self.word_start_indices)): + if self.word_start_indices[index] - start_index > max_length: + break + end_index_candidates.append(self.word_start_indices[index]) + end_index = np.random.choice(end_index_candidates) + sampled_text = self.corpus[start_index:end_index].strip() + return sampled_text + except Exception: + pass + raise RuntimeError("Was not able to generate a valid string") + + +def brown_corpus() -> str: + """Returns a single string with the Brown corpus with all punctuations stripped.""" + sentences = load_nltk_brown_corpus() + corpus = " ".join(itertools.chain.from_iterable(sentences)) + corpus = corpus.translate({ord(c): None for c in string.punctuation}) + corpus = re.sub(" +", " ", corpus) + return corpus + + +def load_nltk_brown_corpus() -> ConcatenatedCorpusView: + """Load the Brown corpus using the NLTK library.""" + nltk.data.path.append(NLTK_DATA_DIRNAME) + try: + nltk.corpus.brown.sents() + except LookupError: + NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + nltk.download("brown", download_dir=NLTK_DATA_DIRNAME) + return nltk.corpus.brown.sents() diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py new file mode 100644 index 0000000..b6a48f5 --- /dev/null +++ b/text_recognizer/data/transforms.py @@ -0,0 +1,266 @@ +"""Transforms for PyTorch datasets.""" +from abc import abstractmethod +from pathlib import Path +import random +from typing import Any, Optional, Union + +from loguru import logger +import numpy as np +from PIL import Image +import torch +from torch import Tensor +import torch.nn.functional as F +from torchvision import transforms +from torchvision.transforms import ( + ColorJitter, + Compose, + Normalize, + RandomAffine, + RandomHorizontalFlip, + RandomRotation, + ToPILImage, + ToTensor, +) + +from text_recognizer.datasets.iam_preprocessor import Preprocessor +from text_recognizer.datasets.util import EmnistMapper + + +class RandomResizeCrop: + """Image transform with random resize and crop applied. + + Stolen from + + https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py + + """ + + def __init__(self, jitter: int = 10, ratio: float = 0.5) -> None: + self.jitter = jitter + self.ratio = ratio + + def __call__(self, img: np.ndarray) -> np.ndarray: + """Applies random crop and rotation to an image.""" + w, h = img.size + + # pad with white: + img = transforms.functional.pad(img, self.jitter, fill=255) + + # crop at random (x, y): + x = self.jitter + random.randint(-self.jitter, self.jitter) + y = self.jitter + random.randint(-self.jitter, self.jitter) + + # randomize aspect ratio: + size_w = w * random.uniform(1 - self.ratio, 1 + self.ratio) + size = (h, int(size_w)) + img = transforms.functional.resized_crop(img, y, x, h, w, size) + return img + + +class Transpose: + """Transposes the EMNIST image to the correct orientation.""" + + def __call__(self, image: Image) -> np.ndarray: + """Swaps axis.""" + return np.array(image).swapaxes(0, 1) + + +class Resize: + """Resizes a tensor to a specified width.""" + + def __init__(self, width: int = 952) -> None: + # The default is 952 because of the IAM dataset. + self.width = width + + def __call__(self, image: Tensor) -> Tensor: + """Resize tensor in the last dimension.""" + return F.interpolate(image, size=self.width, mode="nearest") + + +class AddTokens: + """Adds start of sequence and end of sequence tokens to target tensor.""" + + def __init__(self, pad_token: str, eos_token: str, init_token: str = None) -> None: + self.init_token = init_token + self.pad_token = pad_token + self.eos_token = eos_token + if self.init_token is not None: + self.emnist_mapper = EmnistMapper( + init_token=self.init_token, + pad_token=self.pad_token, + eos_token=self.eos_token, + ) + else: + self.emnist_mapper = EmnistMapper( + pad_token=self.pad_token, eos_token=self.eos_token, + ) + self.pad_value = self.emnist_mapper(self.pad_token) + self.eos_value = self.emnist_mapper(self.eos_token) + + def __call__(self, target: Tensor) -> Tensor: + """Adds a sos token to the begining and a eos token to the end of a target sequence.""" + dtype, device = target.dtype, target.device + + # Find the where padding starts. + pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item() + + target[pad_index] = self.eos_value + + if self.init_token is not None: + self.sos_value = self.emnist_mapper(self.init_token) + sos = torch.tensor([self.sos_value], dtype=dtype, device=device) + target = torch.cat([sos, target], dim=0) + + return target + + +class ApplyContrast: + """Sets everything below a threshold to zero, i.e. increase contrast.""" + + def __init__(self, low: float = 0.0, high: float = 0.25) -> None: + self.low = low + self.high = high + + def __call__(self, x: Tensor) -> Tensor: + """Apply mask binary mask to input tensor.""" + mask = x > np.random.RandomState().uniform(low=self.low, high=self.high) + return x * mask + + +class Unsqueeze: + """Add a dimension to the tensor.""" + + def __call__(self, x: Tensor) -> Tensor: + """Adds dim.""" + return x.unsqueeze(0) + + +class Squeeze: + """Removes the first dimension of a tensor.""" + + def __call__(self, x: Tensor) -> Tensor: + """Removes first dim.""" + return x.squeeze(0) + + +class ToLower: + """Converts target to lower case.""" + + def __call__(self, target: Tensor) -> Tensor: + """Corrects index value in target tensor.""" + device = target.device + return torch.stack([x - 26 if x > 35 else x for x in target]).to(device) + + +class ToCharcters: + """Converts integers to characters.""" + + def __init__( + self, pad_token: str, eos_token: str, init_token: str = None, lower: bool = True + ) -> None: + self.init_token = init_token + self.pad_token = pad_token + self.eos_token = eos_token + if self.init_token is not None: + self.emnist_mapper = EmnistMapper( + init_token=self.init_token, + pad_token=self.pad_token, + eos_token=self.eos_token, + lower=lower, + ) + else: + self.emnist_mapper = EmnistMapper( + pad_token=self.pad_token, eos_token=self.eos_token, lower=lower + ) + + def __call__(self, y: Tensor) -> str: + """Converts a Tensor to a str.""" + return ( + "".join([self.emnist_mapper(int(i)) for i in y]) + .strip("_") + .replace(" ", "▁") + ) + + +class WordPieces: + """Abstract transform for word pieces.""" + + def __init__( + self, + num_features: int, + data_dir: Optional[Union[str, Path]] = None, + tokens: Optional[Union[str, Path]] = None, + lexicon: Optional[Union[str, Path]] = None, + use_words: bool = False, + prepend_wordsep: bool = False, + ) -> None: + if data_dir is None: + data_dir = ( + Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb" + ) + logger.debug(f"Using data dir: {data_dir}") + if not data_dir.exists(): + raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") + else: + data_dir = Path(data_dir) + processed_path = ( + Path(__file__).resolve().parents[3] / "data" / "processed" / "iam_lines" + ) + tokens_path = processed_path / tokens + lexicon_path = processed_path / lexicon + + self.preprocessor = Preprocessor( + data_dir, + num_features, + tokens_path, + lexicon_path, + use_words, + prepend_wordsep, + ) + + @abstractmethod + def __call__(self, *args, **kwargs) -> Any: + """Transforms input.""" + ... + + +class ToWordPieces(WordPieces): + """Transforms str to word pieces.""" + + def __init__( + self, + num_features: int, + data_dir: Optional[Union[str, Path]] = None, + tokens: Optional[Union[str, Path]] = None, + lexicon: Optional[Union[str, Path]] = None, + use_words: bool = False, + prepend_wordsep: bool = False, + ) -> None: + super().__init__( + num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep + ) + + def __call__(self, line: str) -> Tensor: + """Transforms str to word pieces.""" + return self.preprocessor.to_index(line) + + +class ToText(WordPieces): + """Takes word pieces and converts them to text.""" + + def __init__( + self, + num_features: int, + data_dir: Optional[Union[str, Path]] = None, + tokens: Optional[Union[str, Path]] = None, + lexicon: Optional[Union[str, Path]] = None, + use_words: bool = False, + prepend_wordsep: bool = False, + ) -> None: + super().__init__( + num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep + ) + + def __call__(self, x: Tensor) -> str: + """Converts tensor to text.""" + return self.preprocessor.to_text(x.tolist()) diff --git a/text_recognizer/data/util.py b/text_recognizer/data/util.py new file mode 100644 index 0000000..da87756 --- /dev/null +++ b/text_recognizer/data/util.py @@ -0,0 +1,209 @@ +"""Util functions for datasets.""" +import hashlib +import json +import os +from pathlib import Path +import string +from typing import Dict, List, Optional, Union +from urllib.request import urlretrieve + +from loguru import logger +import numpy as np +import torch +from torch import Tensor +from torchvision.datasets import EMNIST +from tqdm import tqdm + +DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data" +ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json" + + +def save_emnist_essentials(emnsit_dataset: EMNIST = EMNIST) -> None: + """Extract and saves EMNIST essentials.""" + labels = emnsit_dataset.classes + labels.sort() + mapping = [(i, str(label)) for i, label in enumerate(labels)] + essentials = { + "mapping": mapping, + "input_shape": tuple(np.array(emnsit_dataset[0][0]).shape[:]), + } + logger.info("Saving emnist essentials...") + with open(ESSENTIALS_FILENAME, "w") as f: + json.dump(essentials, f) + + +def download_emnist() -> None: + """Download the EMNIST dataset via the PyTorch class.""" + logger.info(f"Data directory is: {DATA_DIRNAME}") + dataset = EMNIST(root=DATA_DIRNAME, split="byclass", download=True) + save_emnist_essentials(dataset) + + +class EmnistMapper: + """Mapper between network output to Emnist character.""" + + def __init__( + self, + pad_token: str, + init_token: Optional[str] = None, + eos_token: Optional[str] = None, + lower: bool = False, + ) -> None: + """Loads the emnist essentials file with the mapping and input shape.""" + self.init_token = init_token + self.pad_token = pad_token + self.eos_token = eos_token + self.lower = lower + + self.essentials = self._load_emnist_essentials() + # Load dataset information. + self._mapping = dict(self.essentials["mapping"]) + self._augment_emnist_mapping() + self._inverse_mapping = {v: k for k, v in self.mapping.items()} + self._num_classes = len(self.mapping) + self._input_shape = self.essentials["input_shape"] + + def __call__(self, token: Union[str, int, np.uint8, Tensor]) -> Union[str, int]: + """Maps the token to emnist character or character index. + + If the token is an integer (index), the method will return the Emnist character corresponding to that index. + If the token is a str (Emnist character), the method will return the corresponding index for that character. + + Args: + token (Union[str, int, np.uint8, Tensor]): Either a string or index (integer). + + Returns: + Union[str, int]: The mapping result. + + Raises: + KeyError: If the index or string does not exist in the mapping. + + """ + if ( + (isinstance(token, np.uint8) or isinstance(token, int)) + or torch.is_tensor(token) + and int(token) in self.mapping + ): + return self.mapping[int(token)] + elif isinstance(token, str) and token in self._inverse_mapping: + return self._inverse_mapping[token] + else: + raise KeyError(f"Token {token} does not exist in the mappings.") + + @property + def mapping(self) -> Dict: + """Returns the mapping between index and character.""" + return self._mapping + + @property + def inverse_mapping(self) -> Dict: + """Returns the mapping between character and index.""" + return self._inverse_mapping + + @property + def num_classes(self) -> int: + """Returns the number of classes in the dataset.""" + return self._num_classes + + @property + def input_shape(self) -> List[int]: + """Returns the input shape of the Emnist characters.""" + return self._input_shape + + def _load_emnist_essentials(self) -> Dict: + """Load the EMNIST mapping.""" + with open(str(ESSENTIALS_FILENAME)) as f: + essentials = json.load(f) + return essentials + + def _augment_emnist_mapping(self) -> None: + """Augment the mapping with extra symbols.""" + # Extra symbols in IAM dataset + if self.lower: + self._mapping = { + k: str(v) + for k, v in enumerate(list(range(10)) + list(string.ascii_lowercase)) + } + + extra_symbols = [ + " ", + "!", + '"', + "#", + "&", + "'", + "(", + ")", + "*", + "+", + ",", + "-", + ".", + "/", + ":", + ";", + "?", + ] + + # padding symbol, and acts as blank symbol as well. + extra_symbols.append(self.pad_token) + + if self.init_token is not None: + extra_symbols.append(self.init_token) + + if self.eos_token is not None: + extra_symbols.append(self.eos_token) + + max_key = max(self.mapping.keys()) + extra_mapping = {} + for i, symbol in enumerate(extra_symbols): + extra_mapping[max_key + 1 + i] = symbol + + self._mapping = {**self.mapping, **extra_mapping} + + +def compute_sha256(filename: Union[Path, str]) -> str: + """Returns the SHA256 checksum of a file.""" + with open(filename, "rb") as f: + return hashlib.sha256(f.read()).hexdigest() + + +class TqdmUpTo(tqdm): + """TQDM progress bar when downloading files. + + From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py + + """ + + def update_to( + self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None + ) -> None: + """Updates the progress bar. + + Args: + blocks (int): Number of blocks transferred so far. Defaults to 1. + block_size (int): Size of each block, in tqdm units. Defaults to 1. + total_size (Optional[int]): Total size in tqdm units. Defaults to None. + """ + if total_size is not None: + self.total = total_size # pylint: disable=attribute-defined-outside-init + self.update(blocks * block_size - self.n) + + +def download_url(url: str, filename: str) -> None: + """Downloads a file from url to filename, with a progress bar.""" + with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: + urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec + + +def _download_raw_dataset(metadata: Dict) -> None: + if os.path.exists(metadata["filename"]): + return + logger.info(f"Downloading raw dataset from {metadata['url']}...") + download_url(metadata["url"], metadata["filename"]) + logger.info("Computing SHA-256...") + sha256 = compute_sha256(metadata["filename"]) + if sha256 != metadata["sha256"]: + raise ValueError( + "Downloaded data file SHA-256 does not match that listed in metadata document." + ) diff --git a/text_recognizer/datasets/__init__.py b/text_recognizer/datasets/__init__.py deleted file mode 100644 index 2727b20..0000000 --- a/text_recognizer/datasets/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Dataset modules.""" diff --git a/text_recognizer/datasets/base_data_module.py b/text_recognizer/datasets/base_data_module.py deleted file mode 100644 index f5e7300..0000000 --- a/text_recognizer/datasets/base_data_module.py +++ /dev/null @@ -1,89 +0,0 @@ -"""Base lightning DataModule class.""" -from pathlib import Path -from typing import Dict - -import pytorch_lightning as pl -from torch.utils.data import DataLoader - - -def load_and_print_info(data_module_class: type) -> None: - """Load EMNISTLines and prints info.""" - dataset = data_module_class() - dataset.prepare_data() - dataset.setup() - print(dataset) - - -class BaseDataModule(pl.LightningDataModule): - """Base PyTorch Lightning DataModule.""" - - def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None: - super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers - - # Placeholders for subclasses. - self.dims = None - self.output_dims = None - self.mapping = None - - @classmethod - def data_dirname(cls) -> Path: - """Return the path to the base data directory.""" - return Path(__file__).resolve().parents[2] / "data" - - def config(self) -> Dict: - """Return important settings of the dataset.""" - return { - "input_dim": self.dims, - "output_dims": self.output_dims, - "mapping": self.mapping, - } - - def prepare_data(self) -> None: - """Prepare data for training.""" - pass - - def setup(self, stage: str = None) -> None: - """Split into train, val, test, and set dims. - - Should assign `torch Dataset` objects to self.data_train, self.data_val, and - optionally self.data_test. - - Args: - stage (Any): Variable to set splits. - - """ - self.data_train = None - self.data_val = None - self.data_test = None - - def train_dataloader(self) -> DataLoader: - """Retun DataLoader for train data.""" - return DataLoader( - self.data_train, - shuffle=True, - batch_size=self.batch_size, - num_workers=self.num_workers, - pin_memory=True, - ) - - def val_dataloader(self) -> DataLoader: - """Return DataLoader for val data.""" - return DataLoader( - self.data_val, - shuffle=False, - batch_size=self.batch_size, - num_workers=self.num_workers, - pin_memory=True, - ) - - def test_dataloader(self) -> DataLoader: - """Return DataLoader for val data.""" - return DataLoader( - self.data_test, - shuffle=False, - batch_size=self.batch_size, - num_workers=self.num_workers, - pin_memory=True, - ) diff --git a/text_recognizer/datasets/base_dataset.py b/text_recognizer/datasets/base_dataset.py deleted file mode 100644 index a9e9c24..0000000 --- a/text_recognizer/datasets/base_dataset.py +++ /dev/null @@ -1,73 +0,0 @@ -"""Base PyTorch Dataset class.""" -from typing import Any, Callable, Dict, Sequence, Tuple, Union - -import torch -from torch import Tensor -from torch.utils.data import Dataset - - -class BaseDataset(Dataset): - """ - Base Dataset class that processes data and targets through optional transfroms. - - Args: - data (Union[Sequence, Tensor]): Torch tensors, numpy arrays, or PIL images. - targets (Union[Sequence, Tensor]): Torch tensors or numpy arrays. - tranform (Callable): Function that takes a datum and applies transforms. - target_transform (Callable): Fucntion that takes a target and applies - target transforms. - """ - - def __init__( - self, - data: Union[Sequence, Tensor], - targets: Union[Sequence, Tensor], - transform: Callable = None, - target_transform: Callable = None, - ) -> None: - if len(data) != len(targets): - raise ValueError("Data and targets must be of equal length.") - self.data = data - self.targets = targets - self.transform = transform - self.target_transform = target_transform - - def __len__(self) -> int: - """Return the length of the dataset.""" - return len(self.data) - - def __getitem__(self, index: int) -> Tuple[Any, Any]: - """Return a datum and its target, after processing by transforms. - - Args: - index (int): Index of a datum in the dataset. - - Returns: - Tuple[Any, Any]: Datum and target pair. - - """ - datum, target = self.data[index], self.targets[index] - - if self.transform is not None: - datum = self.transform(datum) - - if self.target_transform is not None: - target = self.target_transform(target) - - return datum, target - - -def convert_strings_to_labels( - strings: Sequence[str], mapping: Dict[str, int], length: int -) -> Tensor: - """ - Convert a sequence of N strings to (N, length) ndarray, with each string wrapped with and tokens, - and padded wiht the

token. - """ - labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["

"] - for i, string in enumerate(strings): - tokens = list(string) - tokens = ["", *tokens, ""] - for j, token in enumerate(tokens): - labels[i, j] = mapping[token] - return labels diff --git a/text_recognizer/datasets/download_utils.py b/text_recognizer/datasets/download_utils.py deleted file mode 100644 index e3dc68c..0000000 --- a/text_recognizer/datasets/download_utils.py +++ /dev/null @@ -1,73 +0,0 @@ -"""Util functions for downloading datasets.""" -import hashlib -from pathlib import Path -from typing import Dict, List, Optional -from urllib.request import urlretrieve - -from loguru import logger -from tqdm import tqdm - - -def _compute_sha256(filename: Path) -> str: - """Returns the SHA256 checksum of a file.""" - with filename.open(mode="rb") as f: - return hashlib.sha256(f.read()).hexdigest() - - -class TqdmUpTo(tqdm): - """TQDM progress bar when downloading files. - - From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py - - """ - - def update_to( - self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None - ) -> None: - """Updates the progress bar. - - Args: - blocks (int): Number of blocks transferred so far. Defaults to 1. - block_size (int): Size of each block, in tqdm units. Defaults to 1. - total_size (Optional[int]): Total size in tqdm units. Defaults to None. - """ - if total_size is not None: - self.total = total_size # pylint: disable=attribute-defined-outside-init - self.update(blocks * block_size - self.n) - - -def _download_url(url: str, filename: str) -> None: - """Downloads a file from url to filename, with a progress bar.""" - with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: - urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec - - -def download_dataset(metadata: Dict, dl_dir: Path) -> Optional[Path]: - """Downloads dataset using a metadata file. - - Args: - metadata (Dict): A metadata file of the dataset. - dl_dir (Path): Download directory for the dataset. - - Returns: - Optional[Path]: Returns filename if dataset is downloaded, None if it already - exists. - - Raises: - ValueError: If the SHA-256 value is not the same between the dataset and - the metadata file. - - """ - dl_dir.mkdir(parents=True, exist_ok=True) - filename = dl_dir / metadata["filename"] - if filename.exists(): - return - logger.info(f"Downloading raw dataset from {metadata['url']} to {filename}...") - _download_url(metadata["url"], filename) - logger.info("Computing the SHA-256...") - sha256 = _compute_sha256(filename) - if sha256 != metadata["sha256"]: - raise ValueError( - "Downloaded data file SHA-256 does not match that listed in metadata document." - ) - return filename diff --git a/text_recognizer/datasets/emnist.py b/text_recognizer/datasets/emnist.py deleted file mode 100644 index 66101b5..0000000 --- a/text_recognizer/datasets/emnist.py +++ /dev/null @@ -1,210 +0,0 @@ -"""EMNIST dataset: downloads it from FSDL aws url if not present.""" -from pathlib import Path -from typing import Sequence, Tuple -import json -import os -import shutil -import zipfile - -import h5py -import numpy as np -from loguru import logger -import toml -import torch -from torch.utils.data import random_split -from torchvision import transforms - -from text_recognizer.datasets.base_dataset import BaseDataset -from text_recognizer.datasets.base_data_module import ( - BaseDataModule, - load_and_print_info, -) -from text_recognizer.datasets.download_utils import download_dataset - - -SEED = 4711 -NUM_SPECIAL_TOKENS = 4 -SAMPLE_TO_BALANCE = True - -RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "emnist" -METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" -DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "emnist" -PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist" -PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5" -ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json" - - -class EMNIST(BaseDataModule): - """ - "The EMNIST dataset is a set of handwritten character digits derived from the NIST Special Database 19 - and converted to a 28x28 pixel image format and dataset structure that directly matches the MNIST dataset." - From https://www.nist.gov/itl/iad/image-group/emnist-dataset - - The data split we will use is - EMNIST ByClass: 814,255 characters. 62 unbalanced classes. - """ - - def __init__( - self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8 - ) -> None: - super().__init__(batch_size, num_workers) - if not ESSENTIALS_FILENAME.exists(): - _download_and_process_emnist() - with ESSENTIALS_FILENAME.open() as f: - essentials = json.load(f) - self.train_fraction = train_fraction - self.mapping = list(essentials["characters"]) - self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)} - self.data_train = None - self.data_val = None - self.data_test = None - self.transform = transforms.Compose([transforms.ToTensor()]) - self.dims = (1, *essentials["input_shape"]) - self.output_dims = (1,) - - def prepare_data(self) -> None: - if not PROCESSED_DATA_FILENAME.exists(): - _download_and_process_emnist() - - def setup(self, stage: str = None) -> None: - if stage == "fit" or stage is None: - with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: - self.x_train = f["x_train"][:] - self.y_train = f["y_train"][:].squeeze().astype(int) - - dataset_train = BaseDataset( - self.x_train, self.y_train, transform=self.transform - ) - train_size = int(self.train_fraction * len(dataset_train)) - val_size = len(dataset_train) - train_size - self.data_train, self.data_val = random_split( - dataset_train, [train_size, val_size], generator=torch.Generator() - ) - - if stage == "test" or stage is None: - with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: - self.x_test = f["x_test"][:] - self.y_test = f["y_test"][:].squeeze().astype(int) - self.data_test = BaseDataset( - self.x_test, self.y_test, transform=self.transform - ) - - def __repr__(self) -> str: - basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.dims}\n" - if not any([self.data_train, self.data_val, self.data_test]): - return basic - - datum, target = next(iter(self.train_dataloader())) - data = ( - f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" - f"Batch x stats: {(datum.shape, datum.dtype, datum.min(), datum.mean(), datum.std(), datum.max())}\n" - f"Batch y stats: {(target.shape, target.dtype, target.min(), target.max())}\n" - ) - - return basic + data - - -def _download_and_process_emnist() -> None: - metadata = toml.load(METADATA_FILENAME) - download_dataset(metadata, DL_DATA_DIRNAME) - _process_raw_dataset(metadata["filename"], DL_DATA_DIRNAME) - - -def _process_raw_dataset(filename: str, dirname: Path) -> None: - logger.info("Unzipping EMNIST...") - curdir = os.getcwd() - os.chdir(dirname) - content = zipfile.ZipFile(filename, "r") - content.extract("matlab/emnist-byclass.mat") - - from scipy.io import loadmat - - logger.info("Loading training data from .mat file") - data = loadmat("matlab/emnist-byclass.mat") - x_train = ( - data["dataset"]["train"][0, 0]["images"][0, 0] - .reshape(-1, 28, 28) - .swapaxes(1, 2) - ) - y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS - x_test = ( - data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) - ) - y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS - - if SAMPLE_TO_BALANCE: - logger.info("Balancing classes to reduce amount of data") - x_train, y_train = _sample_to_balance(x_train, y_train) - x_test, y_test = _sample_to_balance(x_test, y_test) - - logger.info("Saving to HDF5 in a compressed format...") - PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) - with h5py.File(PROCESSED_DATA_FILENAME, "w") as f: - f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf") - f.create_dataset("y_train", data=y_train, dtype="u1", compression="lzf") - f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf") - f.create_dataset("y_test", data=y_test, dtype="u1", compression="lzf") - - logger.info("Saving essential dataset parameters to text_recognizer/datasets...") - mapping = {int(k): chr(v) for k, v in data["dataset"]["mapping"][0, 0]} - characters = _augment_emnist_characters(mapping.values()) - essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])} - - with ESSENTIALS_FILENAME.open(mode="w") as f: - json.dump(essentials, f) - - logger.info("Cleaning up...") - shutil.rmtree("matlab") - os.chdir(curdir) - - -def _sample_to_balance(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: - """Balances the dataset by taking the mean number of instances per class.""" - np.random.seed(SEED) - num_to_sample = int(np.bincount(y.flatten()).mean()) - all_sampled_indices = [] - for label in np.unique(y.flatten()): - indices = np.where(y == label)[0] - sampled_indices = np.unique(np.random.choice(indices, num_to_sample)) - all_sampled_indices.append(sampled_indices) - indices = np.concatenate(all_sampled_indices) - x_sampled = x[indices] - y_sampled = y[indices] - return x_sampled, y_sampled - - -def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]: - """Augment the mapping with extra symbols.""" - # Extra characters from the IAM dataset. - iam_characters = [ - " ", - "!", - '"', - "#", - "&", - "'", - "(", - ")", - "*", - "+", - ",", - "-", - ".", - "/", - ":", - ";", - "?", - ] - - # Also add special tokens for: - # - CTC blank token at index 0 - # - Start token at index 1 - # - End token at index 2 - # - Padding token at index 3 - # Note: Do not forget to update NUM_SPECIAL_TOKENS if changing this! - return ["", "", "", "

", *characters, *iam_characters] - - -def download_emnist() -> None: - """Download dataset from internet, if it does not exists, and displays info.""" - load_and_print_info(EMNIST) diff --git a/text_recognizer/datasets/emnist_essentials.json b/text_recognizer/datasets/emnist_essentials.json deleted file mode 100644 index 3f46a73..0000000 --- a/text_recognizer/datasets/emnist_essentials.json +++ /dev/null @@ -1 +0,0 @@ -{"characters": ["", "", "", "

", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]} diff --git a/text_recognizer/datasets/emnist_lines.py b/text_recognizer/datasets/emnist_lines.py deleted file mode 100644 index 9ebad22..0000000 --- a/text_recognizer/datasets/emnist_lines.py +++ /dev/null @@ -1,280 +0,0 @@ -"""Dataset of generated text from EMNIST characters.""" -from collections import defaultdict -from pathlib import Path -from typing import Callable, Dict, Tuple, Sequence - -import h5py -from loguru import logger -import numpy as np -from PIL import Image -import torch -from torchvision import transforms -from torchvision.transforms.functional import InterpolationMode - -from text_recognizer.datasets.base_dataset import BaseDataset, convert_strings_to_labels -from text_recognizer.datasets.base_data_module import ( - BaseDataModule, - load_and_print_info, -) -from text_recognizer.datasets.emnist import EMNIST -from text_recognizer.datasets.sentence_generator import SentenceGenerator - - -DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist_lines" -ESSENTIALS_FILENAME = ( - Path(__file__).parents[0].resolve() / "emnist_lines_essentials.json" -) - -SEED = 4711 -IMAGE_HEIGHT = 56 -IMAGE_WIDTH = 1024 -IMAGE_X_PADDING = 28 -MAX_OUTPUT_LENGTH = 89 # Same as IAMLines - - -class EMNISTLines(BaseDataModule): - """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST,""" - - def __init__( - self, - augment: bool = True, - batch_size: int = 128, - num_workers: int = 0, - max_length: int = 32, - min_overlap: float = 0.0, - max_overlap: float = 0.33, - num_train: int = 10_000, - num_val: int = 2_000, - num_test: int = 2_000, - ) -> None: - super().__init__(batch_size, num_workers) - - self.augment = augment - self.max_length = max_length - self.min_overlap = min_overlap - self.max_overlap = max_overlap - self.num_train = num_train - self.num_val = num_val - self.num_test = num_test - - self.emnist = EMNIST() - self.mapping = self.emnist.mapping - max_width = ( - int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap)) - + IMAGE_X_PADDING - ) - - if max_width >= IMAGE_WIDTH: - raise ValueError( - f"max_width {max_width} greater than IMAGE_WIDTH {IMAGE_WIDTH}" - ) - - self.dims = ( - self.emnist.dims[0], - IMAGE_HEIGHT, - IMAGE_WIDTH - ) - - if self.max_length >= MAX_OUTPUT_LENGTH: - raise ValueError("max_length greater than MAX_OUTPUT_LENGTH") - - self.output_dims = (MAX_OUTPUT_LENGTH, 1) - self.data_train = None - self.data_val = None - self.data_test = None - - @property - def data_filename(self) -> Path: - """Return name of dataset.""" - return ( - DATA_DIRNAME / (f"ml_{self.max_length}_" - f"o{self.min_overlap:f}_{self.max_overlap:f}_" - f"ntr{self.num_train}_" - f"ntv{self.num_val}_" - f"nte{self.num_test}.h5") - ) - - def prepare_data(self) -> None: - if self.data_filename.exists(): - return - np.random.seed(SEED) - self._generate_data("train") - self._generate_data("val") - self._generate_data("test") - - def setup(self, stage: str = None) -> None: - logger.info("EMNISTLinesDataset loading data from HDF5...") - if stage == "fit" or stage is None: - print(self.data_filename) - with h5py.File(self.data_filename, "r") as f: - x_train = f["x_train"][:] - y_train = torch.LongTensor(f["y_train"][:]) - x_val = f["x_val"][:] - y_val = torch.LongTensor(f["y_val"][:]) - - self.data_train = BaseDataset( - x_train, y_train, transform=_get_transform(augment=self.augment) - ) - self.data_val = BaseDataset( - x_val, y_val, transform=_get_transform(augment=self.augment) - ) - - if stage == "test" or stage is None: - with h5py.File(self.data_filename, "r") as f: - x_test = f["x_test"][:] - y_test = torch.LongTensor(f["y_test"][:]) - - self.data_test = BaseDataset( - x_test, y_test, transform=_get_transform(augment=False) - ) - - def __repr__(self) -> str: - """Return str about dataset.""" - basic = ( - "EMNISTLines2 Dataset\n" # pylint: disable=no-member - f"Min overlap: {self.min_overlap}\n" - f"Max overlap: {self.max_overlap}\n" - f"Num classes: {len(self.mapping)}\n" - f"Dims: {self.dims}\n" - f"Output dims: {self.output_dims}\n" - ) - - if not any([self.data_train, self.data_val, self.data_test]): - return basic - - x, y = next(iter(self.train_dataloader())) - data = ( - f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" - f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" - f"Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" - ) - return basic + data - - def _generate_data(self, split: str) -> None: - logger.info(f"EMNISTLines generating data for {split}...") - sentence_generator = SentenceGenerator( - self.max_length - 2 - ) # Subtract by 2 because start/end token - - emnist = self.emnist - emnist.prepare_data() - emnist.setup() - - if split == "train": - samples_by_char = _get_samples_by_char( - emnist.x_train, emnist.y_train, emnist.mapping - ) - num = self.num_train - elif split == "val": - samples_by_char = _get_samples_by_char( - emnist.x_train, emnist.y_train, emnist.mapping - ) - num = self.num_val - else: - samples_by_char = _get_samples_by_char( - emnist.x_test, emnist.y_test, emnist.mapping - ) - num = self.num_test - - DATA_DIRNAME.mkdir(parents=True, exist_ok=True) - with h5py.File(self.data_filename, "a") as f: - x, y = _create_dataset_of_images( - num, - samples_by_char, - sentence_generator, - self.min_overlap, - self.max_overlap, - self.dims, - ) - y = convert_strings_to_labels( - y, emnist.inverse_mapping, length=MAX_OUTPUT_LENGTH - ) - f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf") - f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf") - - -def _get_samples_by_char( - samples: np.ndarray, labels: np.ndarray, mapping: Dict -) -> defaultdict: - samples_by_char = defaultdict(list) - for sample, label in zip(samples, labels): - samples_by_char[mapping[label]].append(sample) - return samples_by_char - - -def _select_letter_samples_for_string(string: str, samples_by_char: defaultdict): - null_image = torch.zeros((28, 28), dtype=torch.uint8) - sample_image_by_char = {} - for char in string: - if char in sample_image_by_char: - continue - samples = samples_by_char[char] - sample = samples[np.random.choice(len(samples))] if samples else null_image - sample_image_by_char[char] = sample.reshape(28, 28) - return [sample_image_by_char[char] for char in string] - - -def _construct_image_from_string( - string: str, - samples_by_char: defaultdict, - min_overlap: float, - max_overlap: float, - width: int, -) -> torch.Tensor: - overlap = np.random.uniform(min_overlap, max_overlap) - sampled_images = _select_letter_samples_for_string(string, samples_by_char) - N = len(sampled_images) - H, W = sampled_images[0].shape - next_overlap_width = W - int(overlap * W) - concatenated_image = torch.zeros((H, width), dtype=torch.uint8) - x = IMAGE_X_PADDING - for image in sampled_images: - concatenated_image[:, x : (x + W)] += image - x += next_overlap_width - return torch.minimum(torch.Tensor([255]), concatenated_image) - - -def _create_dataset_of_images( - num_samples: int, - samples_by_char: defaultdict, - sentence_generator: SentenceGenerator, - min_overlap: float, - max_overlap: float, - dims: Tuple, -) -> Tuple[torch.Tensor, torch.Tensor]: - images = torch.zeros((num_samples, IMAGE_HEIGHT, dims[2])) - labels = [] - for n in range(num_samples): - label = sentence_generator.generate() - crop = _construct_image_from_string( - label, samples_by_char, min_overlap, max_overlap, dims[-1] - ) - height = crop.shape[0] - y = (IMAGE_HEIGHT - height) // 2 - images[n, y : (y + height), :] = crop - labels.append(label) - return images, labels - - -def _get_transform(augment: bool = False) -> Callable: - if not augment: - return transforms.Compose([transforms.ToTensor()]) - return transforms.Compose( - [ - transforms.ToTensor(), - transforms.ColorJitter(brightness=(0.5, 1.0)), - transforms.RandomAffine( - degrees=3, - translate=(0.0, 0.05), - scale=(0.4, 1.1), - shear=(-40, 50), - interpolation=InterpolationMode.BILINEAR, - fill=0, - ), - ] - ) - - -def generate_emnist_lines() -> None: - """Generates a synthetic handwritten dataset and displays info,""" - load_and_print_info(EMNISTLines) diff --git a/text_recognizer/datasets/iam_dataset.py b/text_recognizer/datasets/iam_dataset.py deleted file mode 100644 index a8998b9..0000000 --- a/text_recognizer/datasets/iam_dataset.py +++ /dev/null @@ -1,133 +0,0 @@ -"""Class for loading the IAM dataset, which encompasses both paragraphs and lines, with associated utilities.""" -import os -from typing import Any, Dict, List -import zipfile - -from boltons.cacheutils import cachedproperty -import defusedxml.ElementTree as ET -from loguru import logger -import toml - -from text_recognizer.datasets.util import _download_raw_dataset, DATA_DIRNAME - -RAW_DATA_DIRNAME = DATA_DIRNAME / "raw" / "iam" -METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" -EXTRACTED_DATASET_DIRNAME = RAW_DATA_DIRNAME / "iamdb" -RAW_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) - -DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be. -LINE_REGION_PADDING = 0 # Add this many pixels around the exact coordinates. - - -class IamDataset: - """IAM dataset. - - "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text, - which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels." - From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database - - The data split we will use is - IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text lines. - The validation set has been merged into the train set. - The train set has 7,101 lines from 326 writers. - The test set has 1,861 lines from 128 writers. - The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only. - - """ - - def __init__(self) -> None: - self.metadata = toml.load(METADATA_FILENAME) - - def load_or_generate_data(self) -> None: - """Downloads IAM dataset if xml files does not exist.""" - if not self.xml_filenames: - self._download_iam() - - @property - def xml_filenames(self) -> List: - """List of xml filenames.""" - return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml")) - - @property - def form_filenames(self) -> List: - """List of forms filenames.""" - return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg")) - - def _download_iam(self) -> None: - curdir = os.getcwd() - os.chdir(RAW_DATA_DIRNAME) - _download_raw_dataset(self.metadata) - _extract_raw_dataset(self.metadata) - os.chdir(curdir) - - @property - def form_filenames_by_id(self) -> Dict: - """Creates a dictionary with filenames as keys and forms as values.""" - return {filename.stem: filename for filename in self.form_filenames} - - @cachedproperty - def line_strings_by_id(self) -> Dict: - """Return a dict from name of IAM form to a list of line texts in it.""" - return { - filename.stem: _get_line_strings_from_xml_file(filename) - for filename in self.xml_filenames - } - - @cachedproperty - def line_regions_by_id(self) -> Dict: - """Return a dict from name of IAM form to a list of (x1, x2, y1, y2) coordinates of all lines in it.""" - return { - filename.stem: _get_line_regions_from_xml_file(filename) - for filename in self.xml_filenames - } - - def __repr__(self) -> str: - """Print info about dataset.""" - return "IAM Dataset\n" f"Number of forms: {len(self.xml_filenames)}\n" - - -def _extract_raw_dataset(metadata: Dict) -> None: - logger.info("Extracting IAM data.") - with zipfile.ZipFile(metadata["filename"], "r") as zip_file: - zip_file.extractall() - - -def _get_line_strings_from_xml_file(filename: str) -> List[str]: - """Get the text content of each line. Note that we replace " with ".""" - xml_root_element = ET.parse(filename).getroot() # nosec - xml_line_elements = xml_root_element.findall("handwritten-part/line") - return [el.attrib["text"].replace(""", '"') for el in xml_line_elements] - - -def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]: - """Get the line region dict for each line.""" - xml_root_element = ET.parse(filename).getroot() # nosec - xml_line_elements = xml_root_element.findall("handwritten-part/line") - return [_get_line_region_from_xml_element(el) for el in xml_line_elements] - - -def _get_line_region_from_xml_element(xml_line: Any) -> Dict[str, int]: - """Extracts coordinates for each line of text.""" - # TODO: fix input! - word_elements = xml_line.findall("word/cmp") - x1s = [int(el.attrib["x"]) for el in word_elements] - y1s = [int(el.attrib["y"]) for el in word_elements] - x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements] - y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements] - return { - "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, - "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, - "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, - "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, - } - - -def main() -> None: - """Initializes the dataset and print info about the dataset.""" - dataset = IamDataset() - dataset.load_or_generate_data() - print(dataset) - - -if __name__ == "__main__": - main() diff --git a/text_recognizer/datasets/iam_lines_dataset.py b/text_recognizer/datasets/iam_lines_dataset.py deleted file mode 100644 index 1cb84bd..0000000 --- a/text_recognizer/datasets/iam_lines_dataset.py +++ /dev/null @@ -1,110 +0,0 @@ -"""IamLinesDataset class.""" -from typing import Callable, Dict, List, Optional, Tuple, Union - -import h5py -from loguru import logger -import torch -from torch import Tensor -from torchvision.transforms import ToTensor - -from text_recognizer.datasets.dataset import Dataset -from text_recognizer.datasets.util import ( - compute_sha256, - DATA_DIRNAME, - download_url, - EmnistMapper, -) - - -PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines" -PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "iam_lines.h5" -PROCESSED_DATA_URL = ( - "https://s3-us-west-2.amazonaws.com/fsdl-public-assets/iam_lines.h5" -) - - -class IamLinesDataset(Dataset): - """IAM lines datasets for handwritten text lines.""" - - def __init__( - self, - train: bool = False, - subsample_fraction: float = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - init_token: Optional[str] = None, - pad_token: Optional[str] = None, - eos_token: Optional[str] = None, - lower: bool = False, - ) -> None: - self.pad_token = "_" if pad_token is None else pad_token - - super().__init__( - train=train, - subsample_fraction=subsample_fraction, - transform=transform, - target_transform=target_transform, - init_token=init_token, - pad_token=pad_token, - eos_token=eos_token, - lower=lower, - ) - - @property - def input_shape(self) -> Tuple: - """Input shape of the data.""" - return self.data.shape[1:] if self.data is not None else None - - @property - def output_shape(self) -> Tuple: - """Output shape of the data.""" - return ( - self.targets.shape[1:] + (self.num_classes,) - if self.targets is not None - else None - ) - - def load_or_generate_data(self) -> None: - """Load or generate dataset data.""" - if not PROCESSED_DATA_FILENAME.exists(): - PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) - logger.info("Downloading IAM lines...") - download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME) - with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: - self._data = f[f"x_{self.split}"][:] - self._targets = f[f"y_{self.split}"][:] - self._subsample() - - def __repr__(self) -> str: - """Print info about the dataset.""" - return ( - "IAM Lines Dataset\n" # pylint: disable=no-member - f"Number classes: {self.num_classes}\n" - f"Mapping: {self.mapper.mapping}\n" - f"Data: {self.data.shape}\n" - f"Targets: {self.targets.shape}\n" - ) - - def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]: - """Fetches data, target pair of the dataset for a given and index or indices. - - Args: - index (Union[int, Tensor]): Either a list or int of indices/index. - - Returns: - Tuple[Tensor, Tensor]: Data target pair. - - """ - if torch.is_tensor(index): - index = index.tolist() - - data = self.data[index] - targets = self.targets[index] - - if self.transform: - data = self.transform(data) - - if self.target_transform: - targets = self.target_transform(targets) - - return data, targets diff --git a/text_recognizer/datasets/iam_paragraphs_dataset.py b/text_recognizer/datasets/iam_paragraphs_dataset.py deleted file mode 100644 index 8ba5142..0000000 --- a/text_recognizer/datasets/iam_paragraphs_dataset.py +++ /dev/null @@ -1,291 +0,0 @@ -"""IamParagraphsDataset class and functions for data processing.""" -import random -from typing import Callable, Dict, List, Optional, Tuple, Union - -import click -import cv2 -import h5py -from loguru import logger -import numpy as np -import torch -from torch import Tensor -from torchvision.transforms import ToTensor - -from text_recognizer import util -from text_recognizer.datasets.dataset import Dataset -from text_recognizer.datasets.iam_dataset import IamDataset -from text_recognizer.datasets.util import ( - compute_sha256, - DATA_DIRNAME, - download_url, - EmnistMapper, -) - -INTERIM_DATA_DIRNAME = DATA_DIRNAME / "interim" / "iam_paragraphs" -DEBUG_CROPS_DIRNAME = INTERIM_DATA_DIRNAME / "debug_crops" -PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_paragraphs" -CROPS_DIRNAME = PROCESSED_DATA_DIRNAME / "crops" -GT_DIRNAME = PROCESSED_DATA_DIRNAME / "gt" - -PARAGRAPH_BUFFER = 50 # Pixels in the IAM form images to leave around the lines. -TEST_FRACTION = 0.2 -SEED = 4711 - - -class IamParagraphsDataset(Dataset): - """IAM Paragraphs dataset for paragraphs of handwritten text.""" - - def __init__( - self, - train: bool = False, - subsample_fraction: float = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - ) -> None: - super().__init__( - train=train, - subsample_fraction=subsample_fraction, - transform=transform, - target_transform=target_transform, - ) - # Load Iam dataset. - self.iam_dataset = IamDataset() - - self.num_classes = 3 - self._input_shape = (256, 256) - self._output_shape = self._input_shape + (self.num_classes,) - self._ids = None - - def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]: - """Fetches data, target pair of the dataset for a given and index or indices. - - Args: - index (Union[int, Tensor]): Either a list or int of indices/index. - - Returns: - Tuple[Tensor, Tensor]: Data target pair. - - """ - if torch.is_tensor(index): - index = index.tolist() - - data = self.data[index] - targets = self.targets[index] - - seed = np.random.randint(SEED) - random.seed(seed) # apply this seed to target tranfsorms - torch.manual_seed(seed) # needed for torchvision 0.7 - if self.transform: - data = self.transform(data) - - random.seed(seed) # apply this seed to target tranfsorms - torch.manual_seed(seed) # needed for torchvision 0.7 - if self.target_transform: - targets = self.target_transform(targets) - - return data, targets.long() - - @property - def ids(self) -> Tensor: - """Ids of the dataset.""" - return self._ids - - def get_data_and_target_from_id(self, id_: str) -> Tuple[Tensor, Tensor]: - """Get data target pair from id.""" - ind = self.ids.index(id_) - return self.data[ind], self.targets[ind] - - def load_or_generate_data(self) -> None: - """Load or generate dataset data.""" - num_actual = len(list(CROPS_DIRNAME.glob("*.jpg"))) - num_targets = len(self.iam_dataset.line_regions_by_id) - - if num_actual < num_targets - 2: - self._process_iam_paragraphs() - - self._data, self._targets, self._ids = _load_iam_paragraphs() - self._get_random_split() - self._subsample() - - def _get_random_split(self) -> None: - np.random.seed(SEED) - num_train = int((1 - TEST_FRACTION) * self.data.shape[0]) - indices = np.random.permutation(self.data.shape[0]) - train_indices, test_indices = indices[:num_train], indices[num_train:] - if self.train: - self._data = self.data[train_indices] - self._targets = self.targets[train_indices] - else: - self._data = self.data[test_indices] - self._targets = self.targets[test_indices] - - def _process_iam_paragraphs(self) -> None: - """Crop the part with the text. - - For each page, crop out the part of it that correspond to the paragraph of text, and make sure all crops are - self.input_shape. The ground truth data is the same size, with a one-hot vector at each pixel - corresponding to labels 0=background, 1=odd-numbered line, 2=even-numbered line - """ - crop_dims = self._decide_on_crop_dims() - CROPS_DIRNAME.mkdir(parents=True, exist_ok=True) - DEBUG_CROPS_DIRNAME.mkdir(parents=True, exist_ok=True) - GT_DIRNAME.mkdir(parents=True, exist_ok=True) - logger.info( - f"Cropping paragraphs, generating ground truth, and saving debugging images to {DEBUG_CROPS_DIRNAME}" - ) - for filename in self.iam_dataset.form_filenames: - id_ = filename.stem - line_region = self.iam_dataset.line_regions_by_id[id_] - _crop_paragraph_image(filename, line_region, crop_dims, self.input_shape) - - def _decide_on_crop_dims(self) -> Tuple[int, int]: - """Decide on the dimensions to crop out of the form image. - - Since image width is larger than a comfortable crop around the longest paragraph, - we will make the crop a square form factor. - And since the found dimensions 610x610 are pretty close to 512x512, - we might as well resize crops and make it exactly that, which lets us - do all kinds of power-of-2 pooling and upsampling should we choose to. - - Returns: - Tuple[int, int]: A tuple of crop dimensions. - - Raises: - RuntimeError: When max crop height is larger than max crop width. - - """ - - sample_form_filename = self.iam_dataset.form_filenames[0] - sample_image = util.read_image(sample_form_filename, grayscale=True) - max_crop_width = sample_image.shape[1] - max_crop_height = _get_max_paragraph_crop_height( - self.iam_dataset.line_regions_by_id - ) - if not max_crop_height <= max_crop_width: - raise RuntimeError( - f"Max crop height is larger then max crop width: {max_crop_height} >= {max_crop_width}" - ) - - crop_dims = (max_crop_width, max_crop_width) - logger.info( - f"Max crop width and height were found to be {max_crop_width}x{max_crop_height}." - ) - logger.info(f"Setting them to {max_crop_width}x{max_crop_width}") - return crop_dims - - def __repr__(self) -> str: - """Return info about the dataset.""" - return ( - "IAM Paragraph Dataset\n" # pylint: disable=no-member - f"Num classes: {self.num_classes}\n" - f"Data: {self.data.shape}\n" - f"Targets: {self.targets.shape}\n" - ) - - -def _get_max_paragraph_crop_height(line_regions_by_id: Dict) -> int: - heights = [] - for regions in line_regions_by_id.values(): - min_y1 = min(region["y1"] for region in regions) - PARAGRAPH_BUFFER - max_y2 = max(region["y2"] for region in regions) + PARAGRAPH_BUFFER - height = max_y2 - min_y1 - heights.append(height) - return max(heights) - - -def _crop_paragraph_image( - filename: str, line_regions: Dict, crop_dims: Tuple[int, int], final_dims: Tuple -) -> None: - image = util.read_image(filename, grayscale=True) - - min_y1 = min(region["y1"] for region in line_regions) - PARAGRAPH_BUFFER - max_y2 = max(region["y2"] for region in line_regions) + PARAGRAPH_BUFFER - height = max_y2 - min_y1 - crop_height = crop_dims[0] - buffer = (crop_height - height) // 2 - - # Generate image crop. - image_crop = 255 * np.ones(crop_dims, dtype=np.uint8) - try: - image_crop[buffer : buffer + height] = image[min_y1:max_y2] - except Exception as e: # pylint: disable=broad-except - logger.error(f"Rescued {filename}: {e}") - return - - # Generate ground truth. - gt_image = np.zeros_like(image_crop, dtype=np.uint8) - for index, region in enumerate(line_regions): - gt_image[ - (region["y1"] - min_y1 + buffer) : (region["y2"] - min_y1 + buffer), - region["x1"] : region["x2"], - ] = (index % 2 + 1) - - # Generate image for debugging. - import matplotlib.pyplot as plt - - cmap = plt.get_cmap("Set1") - image_crop_for_debug = np.dstack([image_crop, image_crop, image_crop]) - for index, region in enumerate(line_regions): - color = [255 * _ for _ in cmap(index)[:-1]] - cv2.rectangle( - image_crop_for_debug, - (region["x1"], region["y1"] - min_y1 + buffer), - (region["x2"], region["y2"] - min_y1 + buffer), - color, - 3, - ) - image_crop_for_debug = cv2.resize( - image_crop_for_debug, final_dims, interpolation=cv2.INTER_AREA - ) - util.write_image(image_crop_for_debug, DEBUG_CROPS_DIRNAME / f"{filename.stem}.jpg") - - image_crop = cv2.resize(image_crop, final_dims, interpolation=cv2.INTER_AREA) - util.write_image(image_crop, CROPS_DIRNAME / f"{filename.stem}.jpg") - - gt_image = cv2.resize(gt_image, final_dims, interpolation=cv2.INTER_NEAREST) - util.write_image(gt_image, GT_DIRNAME / f"{filename.stem}.png") - - -def _load_iam_paragraphs() -> None: - logger.info("Loading IAM paragraph crops and ground truth from image files...") - images = [] - gt_images = [] - ids = [] - for filename in CROPS_DIRNAME.glob("*.jpg"): - id_ = filename.stem - image = util.read_image(filename, grayscale=True) - image = 1.0 - image / 255 - - gt_filename = GT_DIRNAME / f"{id_}.png" - gt_image = util.read_image(gt_filename, grayscale=True) - - images.append(image) - gt_images.append(gt_image) - ids.append(id_) - images = np.array(images).astype(np.float32) - gt_images = np.array(gt_images).astype(np.uint8) - ids = np.array(ids) - return images, gt_images, ids - - -@click.command() -@click.option( - "--subsample_fraction", - type=float, - default=None, - help="The subsampling factor of the dataset.", -) -def main(subsample_fraction: float) -> None: - """Load dataset and print info.""" - logger.info("Creating train set...") - dataset = IamParagraphsDataset(train=True, subsample_fraction=subsample_fraction) - dataset.load_or_generate_data() - print(dataset) - logger.info("Creating test set...") - dataset = IamParagraphsDataset(subsample_fraction=subsample_fraction) - dataset.load_or_generate_data() - print(dataset) - - -if __name__ == "__main__": - main() diff --git a/text_recognizer/datasets/iam_preprocessor.py b/text_recognizer/datasets/iam_preprocessor.py deleted file mode 100644 index a93eb00..0000000 --- a/text_recognizer/datasets/iam_preprocessor.py +++ /dev/null @@ -1,196 +0,0 @@ -"""Preprocessor for extracting word letters from the IAM dataset. - -The code is mostly stolen from: - - https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py - -""" - -import collections -import itertools -from pathlib import Path -import re -from typing import List, Optional, Union - -import click -from loguru import logger -import torch - - -def load_metadata( - data_dir: Path, wordsep: str, use_words: bool = False -) -> collections.defaultdict: - """Loads IAM metadata and returns it as a dictionary.""" - forms = collections.defaultdict(list) - filename = "words.txt" if use_words else "lines.txt" - - with open(data_dir / "ascii" / filename, "r") as f: - lines = (line.strip().split() for line in f if line[0] != "#") - for line in lines: - # Skip word segmentation errors. - if use_words and line[1] == "err": - continue - text = " ".join(line[8:]) - - # Remove garbage tokens: - text = text.replace("#", "") - - # Swap word sep form | to wordsep - text = re.sub(r"\|+|\s", wordsep, text).strip(wordsep) - form_key = "-".join(line[0].split("-")[:2]) - line_key = "-".join(line[0].split("-")[:3]) - box_idx = 4 - use_words - box = tuple(int(val) for val in line[box_idx : box_idx + 4]) - forms[form_key].append({"key": line_key, "box": box, "text": text}) - return forms - - -class Preprocessor: - """A preprocessor for the IAM dataset.""" - - # TODO: add lower case only to when generating... - - def __init__( - self, - data_dir: Union[str, Path], - num_features: int, - tokens_path: Optional[Union[str, Path]] = None, - lexicon_path: Optional[Union[str, Path]] = None, - use_words: bool = False, - prepend_wordsep: bool = False, - ) -> None: - self.wordsep = "▁" - self._use_word = use_words - self._prepend_wordsep = prepend_wordsep - - self.data_dir = Path(data_dir) - - self.forms = load_metadata(self.data_dir, self.wordsep, use_words=use_words) - - # Load the set of graphemes: - graphemes = set() - for _, form in self.forms.items(): - for line in form: - graphemes.update(line["text"].lower()) - self.graphemes = sorted(graphemes) - - # Build the token-to-index and index-to-token maps. - if tokens_path is not None: - with open(tokens_path, "r") as f: - self.tokens = [line.strip() for line in f] - else: - self.tokens = self.graphemes - - if lexicon_path is not None: - with open(lexicon_path, "r") as f: - lexicon = (line.strip().split() for line in f) - lexicon = {line[0]: line[1:] for line in lexicon} - self.lexicon = lexicon - else: - self.lexicon = None - - self.graphemes_to_index = {t: i for i, t in enumerate(self.graphemes)} - self.tokens_to_index = {t: i for i, t in enumerate(self.tokens)} - self.num_features = num_features - self.text = [] - - @property - def num_tokens(self) -> int: - """Returns the number or tokens.""" - return len(self.tokens) - - @property - def use_words(self) -> bool: - """If words are used.""" - return self._use_word - - def extract_train_text(self) -> None: - """Extracts training text.""" - keys = [] - with open(self.data_dir / "task" / "trainset.txt") as f: - keys.extend((line.strip() for line in f)) - - for _, examples in self.forms.items(): - for example in examples: - if example["key"] not in keys: - continue - self.text.append(example["text"].lower()) - - def to_index(self, line: str) -> torch.LongTensor: - """Converts text to a tensor of indices.""" - token_to_index = self.graphemes_to_index - if self.lexicon is not None: - if len(line) > 0: - # If the word is not found in the lexicon, fall back to letters. - line = [ - t - for w in line.split(self.wordsep) - for t in self.lexicon.get(w, self.wordsep + w) - ] - token_to_index = self.tokens_to_index - if self._prepend_wordsep: - line = itertools.chain([self.wordsep], line) - return torch.LongTensor([token_to_index[t] for t in line]) - - def to_text(self, indices: List[int]) -> str: - """Converts indices to text.""" - # Roughly the inverse of `to_index` - encoding = self.graphemes - if self.lexicon is not None: - encoding = self.tokens - return self._post_process(encoding[i] for i in indices) - - def tokens_to_text(self, indices: List[int]) -> str: - """Converts tokens to text.""" - return self._post_process(self.tokens[i] for i in indices) - - def _post_process(self, indices: List[int]) -> str: - """A list join.""" - return "".join(indices).strip(self.wordsep) - - -@click.command() -@click.option("--data_dir", type=str, default=None, help="Path to iam dataset") -@click.option( - "--use_words", is_flag=True, help="Load word segmented dataset instead of lines" -) -@click.option( - "--save_text", type=str, default=None, help="Path to save parsed train text" -) -@click.option("--save_tokens", type=str, default=None, help="Path to save tokens") -def cli( - data_dir: Optional[str], - use_words: bool, - save_text: Optional[str], - save_tokens: Optional[str], -) -> None: - """CLI for extracting text data from the iam dataset.""" - if data_dir is None: - data_dir = ( - Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb" - ) - logger.debug(f"Using data dir: {data_dir}") - if not data_dir.exists(): - raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") - else: - data_dir = Path(data_dir) - - preprocessor = Preprocessor(data_dir, 64, use_words=use_words) - preprocessor.extract_train_text() - - processed_dir = data_dir.parents[2] / "processed" / "iam_lines" - logger.debug(f"Saving processed files at: {processed_dir}") - - if save_text is not None: - logger.info("Saving training text") - with open(processed_dir / save_text, "w") as f: - f.write("\n".join(t for t in preprocessor.text)) - - if save_tokens is not None: - logger.info("Saving tokens") - with open(processed_dir / save_tokens, "w") as f: - f.write("\n".join(preprocessor.tokens)) - - -if __name__ == "__main__": - cli() diff --git a/text_recognizer/datasets/sentence_generator.py b/text_recognizer/datasets/sentence_generator.py deleted file mode 100644 index 53b781c..0000000 --- a/text_recognizer/datasets/sentence_generator.py +++ /dev/null @@ -1,85 +0,0 @@ -"""Downloading the Brown corpus with NLTK for sentence generating.""" - -import itertools -import re -import string -from typing import Optional - -import nltk -from nltk.corpus.reader.util import ConcatenatedCorpusView -import numpy as np - -from text_recognizer.datasets.util import DATA_DIRNAME - -NLTK_DATA_DIRNAME = DATA_DIRNAME / "downloaded" / "nltk" - - -class SentenceGenerator: - """Generates text sentences using the Brown corpus.""" - - def __init__(self, max_length: Optional[int] = None) -> None: - """Loads the corpus and sets word start indices.""" - self.corpus = brown_corpus() - self.word_start_indices = [0] + [ - _.start(0) + 1 for _ in re.finditer(" ", self.corpus) - ] - self.max_length = max_length - - def generate(self, max_length: Optional[int] = None) -> str: - """Generates a word or sentences from the Brown corpus. - - Sample a string from the Brown corpus of length at least one word and at most max_length, padding to - max_length with the '_' characters if sentence is shorter. - - Args: - max_length (Optional[int]): The maximum number of characters in the sentence. Defaults to None. - - Returns: - str: A sentence from the Brown corpus. - - Raises: - ValueError: If max_length was not specified at initialization and not given as an argument. - - """ - if max_length is None: - max_length = self.max_length - if max_length is None: - raise ValueError( - "Must provide max_length to this method or when making this object." - ) - - for _ in range(10): - try: - index = np.random.randint(0, len(self.word_start_indices) - 1) - start_index = self.word_start_indices[index] - end_index_candidates = [] - for index in range(index + 1, len(self.word_start_indices)): - if self.word_start_indices[index] - start_index > max_length: - break - end_index_candidates.append(self.word_start_indices[index]) - end_index = np.random.choice(end_index_candidates) - sampled_text = self.corpus[start_index:end_index].strip() - return sampled_text - except Exception: - pass - raise RuntimeError("Was not able to generate a valid string") - - -def brown_corpus() -> str: - """Returns a single string with the Brown corpus with all punctuations stripped.""" - sentences = load_nltk_brown_corpus() - corpus = " ".join(itertools.chain.from_iterable(sentences)) - corpus = corpus.translate({ord(c): None for c in string.punctuation}) - corpus = re.sub(" +", " ", corpus) - return corpus - - -def load_nltk_brown_corpus() -> ConcatenatedCorpusView: - """Load the Brown corpus using the NLTK library.""" - nltk.data.path.append(NLTK_DATA_DIRNAME) - try: - nltk.corpus.brown.sents() - except LookupError: - NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) - nltk.download("brown", download_dir=NLTK_DATA_DIRNAME) - return nltk.corpus.brown.sents() diff --git a/text_recognizer/datasets/transforms.py b/text_recognizer/datasets/transforms.py deleted file mode 100644 index b6a48f5..0000000 --- a/text_recognizer/datasets/transforms.py +++ /dev/null @@ -1,266 +0,0 @@ -"""Transforms for PyTorch datasets.""" -from abc import abstractmethod -from pathlib import Path -import random -from typing import Any, Optional, Union - -from loguru import logger -import numpy as np -from PIL import Image -import torch -from torch import Tensor -import torch.nn.functional as F -from torchvision import transforms -from torchvision.transforms import ( - ColorJitter, - Compose, - Normalize, - RandomAffine, - RandomHorizontalFlip, - RandomRotation, - ToPILImage, - ToTensor, -) - -from text_recognizer.datasets.iam_preprocessor import Preprocessor -from text_recognizer.datasets.util import EmnistMapper - - -class RandomResizeCrop: - """Image transform with random resize and crop applied. - - Stolen from - - https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py - - """ - - def __init__(self, jitter: int = 10, ratio: float = 0.5) -> None: - self.jitter = jitter - self.ratio = ratio - - def __call__(self, img: np.ndarray) -> np.ndarray: - """Applies random crop and rotation to an image.""" - w, h = img.size - - # pad with white: - img = transforms.functional.pad(img, self.jitter, fill=255) - - # crop at random (x, y): - x = self.jitter + random.randint(-self.jitter, self.jitter) - y = self.jitter + random.randint(-self.jitter, self.jitter) - - # randomize aspect ratio: - size_w = w * random.uniform(1 - self.ratio, 1 + self.ratio) - size = (h, int(size_w)) - img = transforms.functional.resized_crop(img, y, x, h, w, size) - return img - - -class Transpose: - """Transposes the EMNIST image to the correct orientation.""" - - def __call__(self, image: Image) -> np.ndarray: - """Swaps axis.""" - return np.array(image).swapaxes(0, 1) - - -class Resize: - """Resizes a tensor to a specified width.""" - - def __init__(self, width: int = 952) -> None: - # The default is 952 because of the IAM dataset. - self.width = width - - def __call__(self, image: Tensor) -> Tensor: - """Resize tensor in the last dimension.""" - return F.interpolate(image, size=self.width, mode="nearest") - - -class AddTokens: - """Adds start of sequence and end of sequence tokens to target tensor.""" - - def __init__(self, pad_token: str, eos_token: str, init_token: str = None) -> None: - self.init_token = init_token - self.pad_token = pad_token - self.eos_token = eos_token - if self.init_token is not None: - self.emnist_mapper = EmnistMapper( - init_token=self.init_token, - pad_token=self.pad_token, - eos_token=self.eos_token, - ) - else: - self.emnist_mapper = EmnistMapper( - pad_token=self.pad_token, eos_token=self.eos_token, - ) - self.pad_value = self.emnist_mapper(self.pad_token) - self.eos_value = self.emnist_mapper(self.eos_token) - - def __call__(self, target: Tensor) -> Tensor: - """Adds a sos token to the begining and a eos token to the end of a target sequence.""" - dtype, device = target.dtype, target.device - - # Find the where padding starts. - pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item() - - target[pad_index] = self.eos_value - - if self.init_token is not None: - self.sos_value = self.emnist_mapper(self.init_token) - sos = torch.tensor([self.sos_value], dtype=dtype, device=device) - target = torch.cat([sos, target], dim=0) - - return target - - -class ApplyContrast: - """Sets everything below a threshold to zero, i.e. increase contrast.""" - - def __init__(self, low: float = 0.0, high: float = 0.25) -> None: - self.low = low - self.high = high - - def __call__(self, x: Tensor) -> Tensor: - """Apply mask binary mask to input tensor.""" - mask = x > np.random.RandomState().uniform(low=self.low, high=self.high) - return x * mask - - -class Unsqueeze: - """Add a dimension to the tensor.""" - - def __call__(self, x: Tensor) -> Tensor: - """Adds dim.""" - return x.unsqueeze(0) - - -class Squeeze: - """Removes the first dimension of a tensor.""" - - def __call__(self, x: Tensor) -> Tensor: - """Removes first dim.""" - return x.squeeze(0) - - -class ToLower: - """Converts target to lower case.""" - - def __call__(self, target: Tensor) -> Tensor: - """Corrects index value in target tensor.""" - device = target.device - return torch.stack([x - 26 if x > 35 else x for x in target]).to(device) - - -class ToCharcters: - """Converts integers to characters.""" - - def __init__( - self, pad_token: str, eos_token: str, init_token: str = None, lower: bool = True - ) -> None: - self.init_token = init_token - self.pad_token = pad_token - self.eos_token = eos_token - if self.init_token is not None: - self.emnist_mapper = EmnistMapper( - init_token=self.init_token, - pad_token=self.pad_token, - eos_token=self.eos_token, - lower=lower, - ) - else: - self.emnist_mapper = EmnistMapper( - pad_token=self.pad_token, eos_token=self.eos_token, lower=lower - ) - - def __call__(self, y: Tensor) -> str: - """Converts a Tensor to a str.""" - return ( - "".join([self.emnist_mapper(int(i)) for i in y]) - .strip("_") - .replace(" ", "▁") - ) - - -class WordPieces: - """Abstract transform for word pieces.""" - - def __init__( - self, - num_features: int, - data_dir: Optional[Union[str, Path]] = None, - tokens: Optional[Union[str, Path]] = None, - lexicon: Optional[Union[str, Path]] = None, - use_words: bool = False, - prepend_wordsep: bool = False, - ) -> None: - if data_dir is None: - data_dir = ( - Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb" - ) - logger.debug(f"Using data dir: {data_dir}") - if not data_dir.exists(): - raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") - else: - data_dir = Path(data_dir) - processed_path = ( - Path(__file__).resolve().parents[3] / "data" / "processed" / "iam_lines" - ) - tokens_path = processed_path / tokens - lexicon_path = processed_path / lexicon - - self.preprocessor = Preprocessor( - data_dir, - num_features, - tokens_path, - lexicon_path, - use_words, - prepend_wordsep, - ) - - @abstractmethod - def __call__(self, *args, **kwargs) -> Any: - """Transforms input.""" - ... - - -class ToWordPieces(WordPieces): - """Transforms str to word pieces.""" - - def __init__( - self, - num_features: int, - data_dir: Optional[Union[str, Path]] = None, - tokens: Optional[Union[str, Path]] = None, - lexicon: Optional[Union[str, Path]] = None, - use_words: bool = False, - prepend_wordsep: bool = False, - ) -> None: - super().__init__( - num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep - ) - - def __call__(self, line: str) -> Tensor: - """Transforms str to word pieces.""" - return self.preprocessor.to_index(line) - - -class ToText(WordPieces): - """Takes word pieces and converts them to text.""" - - def __init__( - self, - num_features: int, - data_dir: Optional[Union[str, Path]] = None, - tokens: Optional[Union[str, Path]] = None, - lexicon: Optional[Union[str, Path]] = None, - use_words: bool = False, - prepend_wordsep: bool = False, - ) -> None: - super().__init__( - num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep - ) - - def __call__(self, x: Tensor) -> str: - """Converts tensor to text.""" - return self.preprocessor.to_text(x.tolist()) diff --git a/text_recognizer/datasets/util.py b/text_recognizer/datasets/util.py deleted file mode 100644 index da87756..0000000 --- a/text_recognizer/datasets/util.py +++ /dev/null @@ -1,209 +0,0 @@ -"""Util functions for datasets.""" -import hashlib -import json -import os -from pathlib import Path -import string -from typing import Dict, List, Optional, Union -from urllib.request import urlretrieve - -from loguru import logger -import numpy as np -import torch -from torch import Tensor -from torchvision.datasets import EMNIST -from tqdm import tqdm - -DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data" -ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json" - - -def save_emnist_essentials(emnsit_dataset: EMNIST = EMNIST) -> None: - """Extract and saves EMNIST essentials.""" - labels = emnsit_dataset.classes - labels.sort() - mapping = [(i, str(label)) for i, label in enumerate(labels)] - essentials = { - "mapping": mapping, - "input_shape": tuple(np.array(emnsit_dataset[0][0]).shape[:]), - } - logger.info("Saving emnist essentials...") - with open(ESSENTIALS_FILENAME, "w") as f: - json.dump(essentials, f) - - -def download_emnist() -> None: - """Download the EMNIST dataset via the PyTorch class.""" - logger.info(f"Data directory is: {DATA_DIRNAME}") - dataset = EMNIST(root=DATA_DIRNAME, split="byclass", download=True) - save_emnist_essentials(dataset) - - -class EmnistMapper: - """Mapper between network output to Emnist character.""" - - def __init__( - self, - pad_token: str, - init_token: Optional[str] = None, - eos_token: Optional[str] = None, - lower: bool = False, - ) -> None: - """Loads the emnist essentials file with the mapping and input shape.""" - self.init_token = init_token - self.pad_token = pad_token - self.eos_token = eos_token - self.lower = lower - - self.essentials = self._load_emnist_essentials() - # Load dataset information. - self._mapping = dict(self.essentials["mapping"]) - self._augment_emnist_mapping() - self._inverse_mapping = {v: k for k, v in self.mapping.items()} - self._num_classes = len(self.mapping) - self._input_shape = self.essentials["input_shape"] - - def __call__(self, token: Union[str, int, np.uint8, Tensor]) -> Union[str, int]: - """Maps the token to emnist character or character index. - - If the token is an integer (index), the method will return the Emnist character corresponding to that index. - If the token is a str (Emnist character), the method will return the corresponding index for that character. - - Args: - token (Union[str, int, np.uint8, Tensor]): Either a string or index (integer). - - Returns: - Union[str, int]: The mapping result. - - Raises: - KeyError: If the index or string does not exist in the mapping. - - """ - if ( - (isinstance(token, np.uint8) or isinstance(token, int)) - or torch.is_tensor(token) - and int(token) in self.mapping - ): - return self.mapping[int(token)] - elif isinstance(token, str) and token in self._inverse_mapping: - return self._inverse_mapping[token] - else: - raise KeyError(f"Token {token} does not exist in the mappings.") - - @property - def mapping(self) -> Dict: - """Returns the mapping between index and character.""" - return self._mapping - - @property - def inverse_mapping(self) -> Dict: - """Returns the mapping between character and index.""" - return self._inverse_mapping - - @property - def num_classes(self) -> int: - """Returns the number of classes in the dataset.""" - return self._num_classes - - @property - def input_shape(self) -> List[int]: - """Returns the input shape of the Emnist characters.""" - return self._input_shape - - def _load_emnist_essentials(self) -> Dict: - """Load the EMNIST mapping.""" - with open(str(ESSENTIALS_FILENAME)) as f: - essentials = json.load(f) - return essentials - - def _augment_emnist_mapping(self) -> None: - """Augment the mapping with extra symbols.""" - # Extra symbols in IAM dataset - if self.lower: - self._mapping = { - k: str(v) - for k, v in enumerate(list(range(10)) + list(string.ascii_lowercase)) - } - - extra_symbols = [ - " ", - "!", - '"', - "#", - "&", - "'", - "(", - ")", - "*", - "+", - ",", - "-", - ".", - "/", - ":", - ";", - "?", - ] - - # padding symbol, and acts as blank symbol as well. - extra_symbols.append(self.pad_token) - - if self.init_token is not None: - extra_symbols.append(self.init_token) - - if self.eos_token is not None: - extra_symbols.append(self.eos_token) - - max_key = max(self.mapping.keys()) - extra_mapping = {} - for i, symbol in enumerate(extra_symbols): - extra_mapping[max_key + 1 + i] = symbol - - self._mapping = {**self.mapping, **extra_mapping} - - -def compute_sha256(filename: Union[Path, str]) -> str: - """Returns the SHA256 checksum of a file.""" - with open(filename, "rb") as f: - return hashlib.sha256(f.read()).hexdigest() - - -class TqdmUpTo(tqdm): - """TQDM progress bar when downloading files. - - From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py - - """ - - def update_to( - self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None - ) -> None: - """Updates the progress bar. - - Args: - blocks (int): Number of blocks transferred so far. Defaults to 1. - block_size (int): Size of each block, in tqdm units. Defaults to 1. - total_size (Optional[int]): Total size in tqdm units. Defaults to None. - """ - if total_size is not None: - self.total = total_size # pylint: disable=attribute-defined-outside-init - self.update(blocks * block_size - self.n) - - -def download_url(url: str, filename: str) -> None: - """Downloads a file from url to filename, with a progress bar.""" - with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: - urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec - - -def _download_raw_dataset(metadata: Dict) -> None: - if os.path.exists(metadata["filename"]): - return - logger.info(f"Downloading raw dataset from {metadata['url']}...") - download_url(metadata["url"], metadata["filename"]) - logger.info("Computing SHA-256...") - sha256 = compute_sha256(metadata["filename"]) - if sha256 != metadata["sha256"]: - raise ValueError( - "Downloaded data file SHA-256 does not match that listed in metadata document." - ) -- cgit v1.2.3-70-g09d2