diff options
Diffstat (limited to 'text_recognizer/datasets')
-rw-r--r-- | text_recognizer/datasets/__init__.py | 1 | ||||
-rw-r--r-- | text_recognizer/datasets/base_data_module.py | 89 | ||||
-rw-r--r-- | text_recognizer/datasets/base_dataset.py | 73 | ||||
-rw-r--r-- | text_recognizer/datasets/download_utils.py | 73 | ||||
-rw-r--r-- | text_recognizer/datasets/emnist.py | 210 | ||||
-rw-r--r-- | text_recognizer/datasets/emnist_essentials.json | 1 | ||||
-rw-r--r-- | text_recognizer/datasets/emnist_lines.py | 280 | ||||
-rw-r--r-- | text_recognizer/datasets/iam_dataset.py | 133 | ||||
-rw-r--r-- | text_recognizer/datasets/iam_lines_dataset.py | 110 | ||||
-rw-r--r-- | text_recognizer/datasets/iam_paragraphs_dataset.py | 291 | ||||
-rw-r--r-- | text_recognizer/datasets/iam_preprocessor.py | 196 | ||||
-rw-r--r-- | text_recognizer/datasets/sentence_generator.py | 85 | ||||
-rw-r--r-- | text_recognizer/datasets/transforms.py | 266 | ||||
-rw-r--r-- | text_recognizer/datasets/util.py | 209 |
14 files changed, 0 insertions, 2017 deletions
diff --git a/text_recognizer/datasets/__init__.py b/text_recognizer/datasets/__init__.py deleted file mode 100644 index 2727b20..0000000 --- a/text_recognizer/datasets/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Dataset modules.""" diff --git a/text_recognizer/datasets/base_data_module.py b/text_recognizer/datasets/base_data_module.py deleted file mode 100644 index f5e7300..0000000 --- a/text_recognizer/datasets/base_data_module.py +++ /dev/null @@ -1,89 +0,0 @@ -"""Base lightning DataModule class.""" -from pathlib import Path -from typing import Dict - -import pytorch_lightning as pl -from torch.utils.data import DataLoader - - -def load_and_print_info(data_module_class: type) -> None: - """Load EMNISTLines and prints info.""" - dataset = data_module_class() - dataset.prepare_data() - dataset.setup() - print(dataset) - - -class BaseDataModule(pl.LightningDataModule): - """Base PyTorch Lightning DataModule.""" - - def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None: - super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers - - # Placeholders for subclasses. - self.dims = None - self.output_dims = None - self.mapping = None - - @classmethod - def data_dirname(cls) -> Path: - """Return the path to the base data directory.""" - return Path(__file__).resolve().parents[2] / "data" - - def config(self) -> Dict: - """Return important settings of the dataset.""" - return { - "input_dim": self.dims, - "output_dims": self.output_dims, - "mapping": self.mapping, - } - - def prepare_data(self) -> None: - """Prepare data for training.""" - pass - - def setup(self, stage: str = None) -> None: - """Split into train, val, test, and set dims. - - Should assign `torch Dataset` objects to self.data_train, self.data_val, and - optionally self.data_test. - - Args: - stage (Any): Variable to set splits. - - """ - self.data_train = None - self.data_val = None - self.data_test = None - - def train_dataloader(self) -> DataLoader: - """Retun DataLoader for train data.""" - return DataLoader( - self.data_train, - shuffle=True, - batch_size=self.batch_size, - num_workers=self.num_workers, - pin_memory=True, - ) - - def val_dataloader(self) -> DataLoader: - """Return DataLoader for val data.""" - return DataLoader( - self.data_val, - shuffle=False, - batch_size=self.batch_size, - num_workers=self.num_workers, - pin_memory=True, - ) - - def test_dataloader(self) -> DataLoader: - """Return DataLoader for val data.""" - return DataLoader( - self.data_test, - shuffle=False, - batch_size=self.batch_size, - num_workers=self.num_workers, - pin_memory=True, - ) diff --git a/text_recognizer/datasets/base_dataset.py b/text_recognizer/datasets/base_dataset.py deleted file mode 100644 index a9e9c24..0000000 --- a/text_recognizer/datasets/base_dataset.py +++ /dev/null @@ -1,73 +0,0 @@ -"""Base PyTorch Dataset class.""" -from typing import Any, Callable, Dict, Sequence, Tuple, Union - -import torch -from torch import Tensor -from torch.utils.data import Dataset - - -class BaseDataset(Dataset): - """ - Base Dataset class that processes data and targets through optional transfroms. - - Args: - data (Union[Sequence, Tensor]): Torch tensors, numpy arrays, or PIL images. - targets (Union[Sequence, Tensor]): Torch tensors or numpy arrays. - tranform (Callable): Function that takes a datum and applies transforms. - target_transform (Callable): Fucntion that takes a target and applies - target transforms. - """ - - def __init__( - self, - data: Union[Sequence, Tensor], - targets: Union[Sequence, Tensor], - transform: Callable = None, - target_transform: Callable = None, - ) -> None: - if len(data) != len(targets): - raise ValueError("Data and targets must be of equal length.") - self.data = data - self.targets = targets - self.transform = transform - self.target_transform = target_transform - - def __len__(self) -> int: - """Return the length of the dataset.""" - return len(self.data) - - def __getitem__(self, index: int) -> Tuple[Any, Any]: - """Return a datum and its target, after processing by transforms. - - Args: - index (int): Index of a datum in the dataset. - - Returns: - Tuple[Any, Any]: Datum and target pair. - - """ - datum, target = self.data[index], self.targets[index] - - if self.transform is not None: - datum = self.transform(datum) - - if self.target_transform is not None: - target = self.target_transform(target) - - return datum, target - - -def convert_strings_to_labels( - strings: Sequence[str], mapping: Dict[str, int], length: int -) -> Tensor: - """ - Convert a sequence of N strings to (N, length) ndarray, with each string wrapped with <s> and </s> tokens, - and padded wiht the <p> token. - """ - labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["<p>"] - for i, string in enumerate(strings): - tokens = list(string) - tokens = ["<s>", *tokens, "</s>"] - for j, token in enumerate(tokens): - labels[i, j] = mapping[token] - return labels diff --git a/text_recognizer/datasets/download_utils.py b/text_recognizer/datasets/download_utils.py deleted file mode 100644 index e3dc68c..0000000 --- a/text_recognizer/datasets/download_utils.py +++ /dev/null @@ -1,73 +0,0 @@ -"""Util functions for downloading datasets.""" -import hashlib -from pathlib import Path -from typing import Dict, List, Optional -from urllib.request import urlretrieve - -from loguru import logger -from tqdm import tqdm - - -def _compute_sha256(filename: Path) -> str: - """Returns the SHA256 checksum of a file.""" - with filename.open(mode="rb") as f: - return hashlib.sha256(f.read()).hexdigest() - - -class TqdmUpTo(tqdm): - """TQDM progress bar when downloading files. - - From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py - - """ - - def update_to( - self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None - ) -> None: - """Updates the progress bar. - - Args: - blocks (int): Number of blocks transferred so far. Defaults to 1. - block_size (int): Size of each block, in tqdm units. Defaults to 1. - total_size (Optional[int]): Total size in tqdm units. Defaults to None. - """ - if total_size is not None: - self.total = total_size # pylint: disable=attribute-defined-outside-init - self.update(blocks * block_size - self.n) - - -def _download_url(url: str, filename: str) -> None: - """Downloads a file from url to filename, with a progress bar.""" - with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: - urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec - - -def download_dataset(metadata: Dict, dl_dir: Path) -> Optional[Path]: - """Downloads dataset using a metadata file. - - Args: - metadata (Dict): A metadata file of the dataset. - dl_dir (Path): Download directory for the dataset. - - Returns: - Optional[Path]: Returns filename if dataset is downloaded, None if it already - exists. - - Raises: - ValueError: If the SHA-256 value is not the same between the dataset and - the metadata file. - - """ - dl_dir.mkdir(parents=True, exist_ok=True) - filename = dl_dir / metadata["filename"] - if filename.exists(): - return - logger.info(f"Downloading raw dataset from {metadata['url']} to {filename}...") - _download_url(metadata["url"], filename) - logger.info("Computing the SHA-256...") - sha256 = _compute_sha256(filename) - if sha256 != metadata["sha256"]: - raise ValueError( - "Downloaded data file SHA-256 does not match that listed in metadata document." - ) - return filename diff --git a/text_recognizer/datasets/emnist.py b/text_recognizer/datasets/emnist.py deleted file mode 100644 index 66101b5..0000000 --- a/text_recognizer/datasets/emnist.py +++ /dev/null @@ -1,210 +0,0 @@ -"""EMNIST dataset: downloads it from FSDL aws url if not present.""" -from pathlib import Path -from typing import Sequence, Tuple -import json -import os -import shutil -import zipfile - -import h5py -import numpy as np -from loguru import logger -import toml -import torch -from torch.utils.data import random_split -from torchvision import transforms - -from text_recognizer.datasets.base_dataset import BaseDataset -from text_recognizer.datasets.base_data_module import ( - BaseDataModule, - load_and_print_info, -) -from text_recognizer.datasets.download_utils import download_dataset - - -SEED = 4711 -NUM_SPECIAL_TOKENS = 4 -SAMPLE_TO_BALANCE = True - -RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "emnist" -METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" -DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "emnist" -PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist" -PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5" -ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json" - - -class EMNIST(BaseDataModule): - """ - "The EMNIST dataset is a set of handwritten character digits derived from the NIST Special Database 19 - and converted to a 28x28 pixel image format and dataset structure that directly matches the MNIST dataset." - From https://www.nist.gov/itl/iad/image-group/emnist-dataset - - The data split we will use is - EMNIST ByClass: 814,255 characters. 62 unbalanced classes. - """ - - def __init__( - self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8 - ) -> None: - super().__init__(batch_size, num_workers) - if not ESSENTIALS_FILENAME.exists(): - _download_and_process_emnist() - with ESSENTIALS_FILENAME.open() as f: - essentials = json.load(f) - self.train_fraction = train_fraction - self.mapping = list(essentials["characters"]) - self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)} - self.data_train = None - self.data_val = None - self.data_test = None - self.transform = transforms.Compose([transforms.ToTensor()]) - self.dims = (1, *essentials["input_shape"]) - self.output_dims = (1,) - - def prepare_data(self) -> None: - if not PROCESSED_DATA_FILENAME.exists(): - _download_and_process_emnist() - - def setup(self, stage: str = None) -> None: - if stage == "fit" or stage is None: - with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: - self.x_train = f["x_train"][:] - self.y_train = f["y_train"][:].squeeze().astype(int) - - dataset_train = BaseDataset( - self.x_train, self.y_train, transform=self.transform - ) - train_size = int(self.train_fraction * len(dataset_train)) - val_size = len(dataset_train) - train_size - self.data_train, self.data_val = random_split( - dataset_train, [train_size, val_size], generator=torch.Generator() - ) - - if stage == "test" or stage is None: - with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: - self.x_test = f["x_test"][:] - self.y_test = f["y_test"][:].squeeze().astype(int) - self.data_test = BaseDataset( - self.x_test, self.y_test, transform=self.transform - ) - - def __repr__(self) -> str: - basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.dims}\n" - if not any([self.data_train, self.data_val, self.data_test]): - return basic - - datum, target = next(iter(self.train_dataloader())) - data = ( - f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" - f"Batch x stats: {(datum.shape, datum.dtype, datum.min(), datum.mean(), datum.std(), datum.max())}\n" - f"Batch y stats: {(target.shape, target.dtype, target.min(), target.max())}\n" - ) - - return basic + data - - -def _download_and_process_emnist() -> None: - metadata = toml.load(METADATA_FILENAME) - download_dataset(metadata, DL_DATA_DIRNAME) - _process_raw_dataset(metadata["filename"], DL_DATA_DIRNAME) - - -def _process_raw_dataset(filename: str, dirname: Path) -> None: - logger.info("Unzipping EMNIST...") - curdir = os.getcwd() - os.chdir(dirname) - content = zipfile.ZipFile(filename, "r") - content.extract("matlab/emnist-byclass.mat") - - from scipy.io import loadmat - - logger.info("Loading training data from .mat file") - data = loadmat("matlab/emnist-byclass.mat") - x_train = ( - data["dataset"]["train"][0, 0]["images"][0, 0] - .reshape(-1, 28, 28) - .swapaxes(1, 2) - ) - y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS - x_test = ( - data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) - ) - y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS - - if SAMPLE_TO_BALANCE: - logger.info("Balancing classes to reduce amount of data") - x_train, y_train = _sample_to_balance(x_train, y_train) - x_test, y_test = _sample_to_balance(x_test, y_test) - - logger.info("Saving to HDF5 in a compressed format...") - PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) - with h5py.File(PROCESSED_DATA_FILENAME, "w") as f: - f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf") - f.create_dataset("y_train", data=y_train, dtype="u1", compression="lzf") - f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf") - f.create_dataset("y_test", data=y_test, dtype="u1", compression="lzf") - - logger.info("Saving essential dataset parameters to text_recognizer/datasets...") - mapping = {int(k): chr(v) for k, v in data["dataset"]["mapping"][0, 0]} - characters = _augment_emnist_characters(mapping.values()) - essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])} - - with ESSENTIALS_FILENAME.open(mode="w") as f: - json.dump(essentials, f) - - logger.info("Cleaning up...") - shutil.rmtree("matlab") - os.chdir(curdir) - - -def _sample_to_balance(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: - """Balances the dataset by taking the mean number of instances per class.""" - np.random.seed(SEED) - num_to_sample = int(np.bincount(y.flatten()).mean()) - all_sampled_indices = [] - for label in np.unique(y.flatten()): - indices = np.where(y == label)[0] - sampled_indices = np.unique(np.random.choice(indices, num_to_sample)) - all_sampled_indices.append(sampled_indices) - indices = np.concatenate(all_sampled_indices) - x_sampled = x[indices] - y_sampled = y[indices] - return x_sampled, y_sampled - - -def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]: - """Augment the mapping with extra symbols.""" - # Extra characters from the IAM dataset. - iam_characters = [ - " ", - "!", - '"', - "#", - "&", - "'", - "(", - ")", - "*", - "+", - ",", - "-", - ".", - "/", - ":", - ";", - "?", - ] - - # Also add special tokens for: - # - CTC blank token at index 0 - # - Start token at index 1 - # - End token at index 2 - # - Padding token at index 3 - # Note: Do not forget to update NUM_SPECIAL_TOKENS if changing this! - return ["<b>", "<s>", "</s>", "<p>", *characters, *iam_characters] - - -def download_emnist() -> None: - """Download dataset from internet, if it does not exists, and displays info.""" - load_and_print_info(EMNIST) diff --git a/text_recognizer/datasets/emnist_essentials.json b/text_recognizer/datasets/emnist_essentials.json deleted file mode 100644 index 3f46a73..0000000 --- a/text_recognizer/datasets/emnist_essentials.json +++ /dev/null @@ -1 +0,0 @@ -{"characters": ["<b>", "<s>", "</s>", "<p>", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]} diff --git a/text_recognizer/datasets/emnist_lines.py b/text_recognizer/datasets/emnist_lines.py deleted file mode 100644 index 9ebad22..0000000 --- a/text_recognizer/datasets/emnist_lines.py +++ /dev/null @@ -1,280 +0,0 @@ -"""Dataset of generated text from EMNIST characters.""" -from collections import defaultdict -from pathlib import Path -from typing import Callable, Dict, Tuple, Sequence - -import h5py -from loguru import logger -import numpy as np -from PIL import Image -import torch -from torchvision import transforms -from torchvision.transforms.functional import InterpolationMode - -from text_recognizer.datasets.base_dataset import BaseDataset, convert_strings_to_labels -from text_recognizer.datasets.base_data_module import ( - BaseDataModule, - load_and_print_info, -) -from text_recognizer.datasets.emnist import EMNIST -from text_recognizer.datasets.sentence_generator import SentenceGenerator - - -DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist_lines" -ESSENTIALS_FILENAME = ( - Path(__file__).parents[0].resolve() / "emnist_lines_essentials.json" -) - -SEED = 4711 -IMAGE_HEIGHT = 56 -IMAGE_WIDTH = 1024 -IMAGE_X_PADDING = 28 -MAX_OUTPUT_LENGTH = 89 # Same as IAMLines - - -class EMNISTLines(BaseDataModule): - """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST,""" - - def __init__( - self, - augment: bool = True, - batch_size: int = 128, - num_workers: int = 0, - max_length: int = 32, - min_overlap: float = 0.0, - max_overlap: float = 0.33, - num_train: int = 10_000, - num_val: int = 2_000, - num_test: int = 2_000, - ) -> None: - super().__init__(batch_size, num_workers) - - self.augment = augment - self.max_length = max_length - self.min_overlap = min_overlap - self.max_overlap = max_overlap - self.num_train = num_train - self.num_val = num_val - self.num_test = num_test - - self.emnist = EMNIST() - self.mapping = self.emnist.mapping - max_width = ( - int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap)) - + IMAGE_X_PADDING - ) - - if max_width >= IMAGE_WIDTH: - raise ValueError( - f"max_width {max_width} greater than IMAGE_WIDTH {IMAGE_WIDTH}" - ) - - self.dims = ( - self.emnist.dims[0], - IMAGE_HEIGHT, - IMAGE_WIDTH - ) - - if self.max_length >= MAX_OUTPUT_LENGTH: - raise ValueError("max_length greater than MAX_OUTPUT_LENGTH") - - self.output_dims = (MAX_OUTPUT_LENGTH, 1) - self.data_train = None - self.data_val = None - self.data_test = None - - @property - def data_filename(self) -> Path: - """Return name of dataset.""" - return ( - DATA_DIRNAME / (f"ml_{self.max_length}_" - f"o{self.min_overlap:f}_{self.max_overlap:f}_" - f"ntr{self.num_train}_" - f"ntv{self.num_val}_" - f"nte{self.num_test}.h5") - ) - - def prepare_data(self) -> None: - if self.data_filename.exists(): - return - np.random.seed(SEED) - self._generate_data("train") - self._generate_data("val") - self._generate_data("test") - - def setup(self, stage: str = None) -> None: - logger.info("EMNISTLinesDataset loading data from HDF5...") - if stage == "fit" or stage is None: - print(self.data_filename) - with h5py.File(self.data_filename, "r") as f: - x_train = f["x_train"][:] - y_train = torch.LongTensor(f["y_train"][:]) - x_val = f["x_val"][:] - y_val = torch.LongTensor(f["y_val"][:]) - - self.data_train = BaseDataset( - x_train, y_train, transform=_get_transform(augment=self.augment) - ) - self.data_val = BaseDataset( - x_val, y_val, transform=_get_transform(augment=self.augment) - ) - - if stage == "test" or stage is None: - with h5py.File(self.data_filename, "r") as f: - x_test = f["x_test"][:] - y_test = torch.LongTensor(f["y_test"][:]) - - self.data_test = BaseDataset( - x_test, y_test, transform=_get_transform(augment=False) - ) - - def __repr__(self) -> str: - """Return str about dataset.""" - basic = ( - "EMNISTLines2 Dataset\n" # pylint: disable=no-member - f"Min overlap: {self.min_overlap}\n" - f"Max overlap: {self.max_overlap}\n" - f"Num classes: {len(self.mapping)}\n" - f"Dims: {self.dims}\n" - f"Output dims: {self.output_dims}\n" - ) - - if not any([self.data_train, self.data_val, self.data_test]): - return basic - - x, y = next(iter(self.train_dataloader())) - data = ( - f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" - f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" - f"Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" - ) - return basic + data - - def _generate_data(self, split: str) -> None: - logger.info(f"EMNISTLines generating data for {split}...") - sentence_generator = SentenceGenerator( - self.max_length - 2 - ) # Subtract by 2 because start/end token - - emnist = self.emnist - emnist.prepare_data() - emnist.setup() - - if split == "train": - samples_by_char = _get_samples_by_char( - emnist.x_train, emnist.y_train, emnist.mapping - ) - num = self.num_train - elif split == "val": - samples_by_char = _get_samples_by_char( - emnist.x_train, emnist.y_train, emnist.mapping - ) - num = self.num_val - else: - samples_by_char = _get_samples_by_char( - emnist.x_test, emnist.y_test, emnist.mapping - ) - num = self.num_test - - DATA_DIRNAME.mkdir(parents=True, exist_ok=True) - with h5py.File(self.data_filename, "a") as f: - x, y = _create_dataset_of_images( - num, - samples_by_char, - sentence_generator, - self.min_overlap, - self.max_overlap, - self.dims, - ) - y = convert_strings_to_labels( - y, emnist.inverse_mapping, length=MAX_OUTPUT_LENGTH - ) - f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf") - f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf") - - -def _get_samples_by_char( - samples: np.ndarray, labels: np.ndarray, mapping: Dict -) -> defaultdict: - samples_by_char = defaultdict(list) - for sample, label in zip(samples, labels): - samples_by_char[mapping[label]].append(sample) - return samples_by_char - - -def _select_letter_samples_for_string(string: str, samples_by_char: defaultdict): - null_image = torch.zeros((28, 28), dtype=torch.uint8) - sample_image_by_char = {} - for char in string: - if char in sample_image_by_char: - continue - samples = samples_by_char[char] - sample = samples[np.random.choice(len(samples))] if samples else null_image - sample_image_by_char[char] = sample.reshape(28, 28) - return [sample_image_by_char[char] for char in string] - - -def _construct_image_from_string( - string: str, - samples_by_char: defaultdict, - min_overlap: float, - max_overlap: float, - width: int, -) -> torch.Tensor: - overlap = np.random.uniform(min_overlap, max_overlap) - sampled_images = _select_letter_samples_for_string(string, samples_by_char) - N = len(sampled_images) - H, W = sampled_images[0].shape - next_overlap_width = W - int(overlap * W) - concatenated_image = torch.zeros((H, width), dtype=torch.uint8) - x = IMAGE_X_PADDING - for image in sampled_images: - concatenated_image[:, x : (x + W)] += image - x += next_overlap_width - return torch.minimum(torch.Tensor([255]), concatenated_image) - - -def _create_dataset_of_images( - num_samples: int, - samples_by_char: defaultdict, - sentence_generator: SentenceGenerator, - min_overlap: float, - max_overlap: float, - dims: Tuple, -) -> Tuple[torch.Tensor, torch.Tensor]: - images = torch.zeros((num_samples, IMAGE_HEIGHT, dims[2])) - labels = [] - for n in range(num_samples): - label = sentence_generator.generate() - crop = _construct_image_from_string( - label, samples_by_char, min_overlap, max_overlap, dims[-1] - ) - height = crop.shape[0] - y = (IMAGE_HEIGHT - height) // 2 - images[n, y : (y + height), :] = crop - labels.append(label) - return images, labels - - -def _get_transform(augment: bool = False) -> Callable: - if not augment: - return transforms.Compose([transforms.ToTensor()]) - return transforms.Compose( - [ - transforms.ToTensor(), - transforms.ColorJitter(brightness=(0.5, 1.0)), - transforms.RandomAffine( - degrees=3, - translate=(0.0, 0.05), - scale=(0.4, 1.1), - shear=(-40, 50), - interpolation=InterpolationMode.BILINEAR, - fill=0, - ), - ] - ) - - -def generate_emnist_lines() -> None: - """Generates a synthetic handwritten dataset and displays info,""" - load_and_print_info(EMNISTLines) diff --git a/text_recognizer/datasets/iam_dataset.py b/text_recognizer/datasets/iam_dataset.py deleted file mode 100644 index a8998b9..0000000 --- a/text_recognizer/datasets/iam_dataset.py +++ /dev/null @@ -1,133 +0,0 @@ -"""Class for loading the IAM dataset, which encompasses both paragraphs and lines, with associated utilities.""" -import os -from typing import Any, Dict, List -import zipfile - -from boltons.cacheutils import cachedproperty -import defusedxml.ElementTree as ET -from loguru import logger -import toml - -from text_recognizer.datasets.util import _download_raw_dataset, DATA_DIRNAME - -RAW_DATA_DIRNAME = DATA_DIRNAME / "raw" / "iam" -METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" -EXTRACTED_DATASET_DIRNAME = RAW_DATA_DIRNAME / "iamdb" -RAW_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) - -DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be. -LINE_REGION_PADDING = 0 # Add this many pixels around the exact coordinates. - - -class IamDataset: - """IAM dataset. - - "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text, - which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels." - From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database - - The data split we will use is - IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text lines. - The validation set has been merged into the train set. - The train set has 7,101 lines from 326 writers. - The test set has 1,861 lines from 128 writers. - The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only. - - """ - - def __init__(self) -> None: - self.metadata = toml.load(METADATA_FILENAME) - - def load_or_generate_data(self) -> None: - """Downloads IAM dataset if xml files does not exist.""" - if not self.xml_filenames: - self._download_iam() - - @property - def xml_filenames(self) -> List: - """List of xml filenames.""" - return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml")) - - @property - def form_filenames(self) -> List: - """List of forms filenames.""" - return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg")) - - def _download_iam(self) -> None: - curdir = os.getcwd() - os.chdir(RAW_DATA_DIRNAME) - _download_raw_dataset(self.metadata) - _extract_raw_dataset(self.metadata) - os.chdir(curdir) - - @property - def form_filenames_by_id(self) -> Dict: - """Creates a dictionary with filenames as keys and forms as values.""" - return {filename.stem: filename for filename in self.form_filenames} - - @cachedproperty - def line_strings_by_id(self) -> Dict: - """Return a dict from name of IAM form to a list of line texts in it.""" - return { - filename.stem: _get_line_strings_from_xml_file(filename) - for filename in self.xml_filenames - } - - @cachedproperty - def line_regions_by_id(self) -> Dict: - """Return a dict from name of IAM form to a list of (x1, x2, y1, y2) coordinates of all lines in it.""" - return { - filename.stem: _get_line_regions_from_xml_file(filename) - for filename in self.xml_filenames - } - - def __repr__(self) -> str: - """Print info about dataset.""" - return "IAM Dataset\n" f"Number of forms: {len(self.xml_filenames)}\n" - - -def _extract_raw_dataset(metadata: Dict) -> None: - logger.info("Extracting IAM data.") - with zipfile.ZipFile(metadata["filename"], "r") as zip_file: - zip_file.extractall() - - -def _get_line_strings_from_xml_file(filename: str) -> List[str]: - """Get the text content of each line. Note that we replace " with ".""" - xml_root_element = ET.parse(filename).getroot() # nosec - xml_line_elements = xml_root_element.findall("handwritten-part/line") - return [el.attrib["text"].replace(""", '"') for el in xml_line_elements] - - -def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]: - """Get the line region dict for each line.""" - xml_root_element = ET.parse(filename).getroot() # nosec - xml_line_elements = xml_root_element.findall("handwritten-part/line") - return [_get_line_region_from_xml_element(el) for el in xml_line_elements] - - -def _get_line_region_from_xml_element(xml_line: Any) -> Dict[str, int]: - """Extracts coordinates for each line of text.""" - # TODO: fix input! - word_elements = xml_line.findall("word/cmp") - x1s = [int(el.attrib["x"]) for el in word_elements] - y1s = [int(el.attrib["y"]) for el in word_elements] - x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements] - y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements] - return { - "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, - "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, - "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, - "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, - } - - -def main() -> None: - """Initializes the dataset and print info about the dataset.""" - dataset = IamDataset() - dataset.load_or_generate_data() - print(dataset) - - -if __name__ == "__main__": - main() diff --git a/text_recognizer/datasets/iam_lines_dataset.py b/text_recognizer/datasets/iam_lines_dataset.py deleted file mode 100644 index 1cb84bd..0000000 --- a/text_recognizer/datasets/iam_lines_dataset.py +++ /dev/null @@ -1,110 +0,0 @@ -"""IamLinesDataset class.""" -from typing import Callable, Dict, List, Optional, Tuple, Union - -import h5py -from loguru import logger -import torch -from torch import Tensor -from torchvision.transforms import ToTensor - -from text_recognizer.datasets.dataset import Dataset -from text_recognizer.datasets.util import ( - compute_sha256, - DATA_DIRNAME, - download_url, - EmnistMapper, -) - - -PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines" -PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "iam_lines.h5" -PROCESSED_DATA_URL = ( - "https://s3-us-west-2.amazonaws.com/fsdl-public-assets/iam_lines.h5" -) - - -class IamLinesDataset(Dataset): - """IAM lines datasets for handwritten text lines.""" - - def __init__( - self, - train: bool = False, - subsample_fraction: float = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - init_token: Optional[str] = None, - pad_token: Optional[str] = None, - eos_token: Optional[str] = None, - lower: bool = False, - ) -> None: - self.pad_token = "_" if pad_token is None else pad_token - - super().__init__( - train=train, - subsample_fraction=subsample_fraction, - transform=transform, - target_transform=target_transform, - init_token=init_token, - pad_token=pad_token, - eos_token=eos_token, - lower=lower, - ) - - @property - def input_shape(self) -> Tuple: - """Input shape of the data.""" - return self.data.shape[1:] if self.data is not None else None - - @property - def output_shape(self) -> Tuple: - """Output shape of the data.""" - return ( - self.targets.shape[1:] + (self.num_classes,) - if self.targets is not None - else None - ) - - def load_or_generate_data(self) -> None: - """Load or generate dataset data.""" - if not PROCESSED_DATA_FILENAME.exists(): - PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) - logger.info("Downloading IAM lines...") - download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME) - with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: - self._data = f[f"x_{self.split}"][:] - self._targets = f[f"y_{self.split}"][:] - self._subsample() - - def __repr__(self) -> str: - """Print info about the dataset.""" - return ( - "IAM Lines Dataset\n" # pylint: disable=no-member - f"Number classes: {self.num_classes}\n" - f"Mapping: {self.mapper.mapping}\n" - f"Data: {self.data.shape}\n" - f"Targets: {self.targets.shape}\n" - ) - - def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]: - """Fetches data, target pair of the dataset for a given and index or indices. - - Args: - index (Union[int, Tensor]): Either a list or int of indices/index. - - Returns: - Tuple[Tensor, Tensor]: Data target pair. - - """ - if torch.is_tensor(index): - index = index.tolist() - - data = self.data[index] - targets = self.targets[index] - - if self.transform: - data = self.transform(data) - - if self.target_transform: - targets = self.target_transform(targets) - - return data, targets diff --git a/text_recognizer/datasets/iam_paragraphs_dataset.py b/text_recognizer/datasets/iam_paragraphs_dataset.py deleted file mode 100644 index 8ba5142..0000000 --- a/text_recognizer/datasets/iam_paragraphs_dataset.py +++ /dev/null @@ -1,291 +0,0 @@ -"""IamParagraphsDataset class and functions for data processing.""" -import random -from typing import Callable, Dict, List, Optional, Tuple, Union - -import click -import cv2 -import h5py -from loguru import logger -import numpy as np -import torch -from torch import Tensor -from torchvision.transforms import ToTensor - -from text_recognizer import util -from text_recognizer.datasets.dataset import Dataset -from text_recognizer.datasets.iam_dataset import IamDataset -from text_recognizer.datasets.util import ( - compute_sha256, - DATA_DIRNAME, - download_url, - EmnistMapper, -) - -INTERIM_DATA_DIRNAME = DATA_DIRNAME / "interim" / "iam_paragraphs" -DEBUG_CROPS_DIRNAME = INTERIM_DATA_DIRNAME / "debug_crops" -PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_paragraphs" -CROPS_DIRNAME = PROCESSED_DATA_DIRNAME / "crops" -GT_DIRNAME = PROCESSED_DATA_DIRNAME / "gt" - -PARAGRAPH_BUFFER = 50 # Pixels in the IAM form images to leave around the lines. -TEST_FRACTION = 0.2 -SEED = 4711 - - -class IamParagraphsDataset(Dataset): - """IAM Paragraphs dataset for paragraphs of handwritten text.""" - - def __init__( - self, - train: bool = False, - subsample_fraction: float = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - ) -> None: - super().__init__( - train=train, - subsample_fraction=subsample_fraction, - transform=transform, - target_transform=target_transform, - ) - # Load Iam dataset. - self.iam_dataset = IamDataset() - - self.num_classes = 3 - self._input_shape = (256, 256) - self._output_shape = self._input_shape + (self.num_classes,) - self._ids = None - - def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]: - """Fetches data, target pair of the dataset for a given and index or indices. - - Args: - index (Union[int, Tensor]): Either a list or int of indices/index. - - Returns: - Tuple[Tensor, Tensor]: Data target pair. - - """ - if torch.is_tensor(index): - index = index.tolist() - - data = self.data[index] - targets = self.targets[index] - - seed = np.random.randint(SEED) - random.seed(seed) # apply this seed to target tranfsorms - torch.manual_seed(seed) # needed for torchvision 0.7 - if self.transform: - data = self.transform(data) - - random.seed(seed) # apply this seed to target tranfsorms - torch.manual_seed(seed) # needed for torchvision 0.7 - if self.target_transform: - targets = self.target_transform(targets) - - return data, targets.long() - - @property - def ids(self) -> Tensor: - """Ids of the dataset.""" - return self._ids - - def get_data_and_target_from_id(self, id_: str) -> Tuple[Tensor, Tensor]: - """Get data target pair from id.""" - ind = self.ids.index(id_) - return self.data[ind], self.targets[ind] - - def load_or_generate_data(self) -> None: - """Load or generate dataset data.""" - num_actual = len(list(CROPS_DIRNAME.glob("*.jpg"))) - num_targets = len(self.iam_dataset.line_regions_by_id) - - if num_actual < num_targets - 2: - self._process_iam_paragraphs() - - self._data, self._targets, self._ids = _load_iam_paragraphs() - self._get_random_split() - self._subsample() - - def _get_random_split(self) -> None: - np.random.seed(SEED) - num_train = int((1 - TEST_FRACTION) * self.data.shape[0]) - indices = np.random.permutation(self.data.shape[0]) - train_indices, test_indices = indices[:num_train], indices[num_train:] - if self.train: - self._data = self.data[train_indices] - self._targets = self.targets[train_indices] - else: - self._data = self.data[test_indices] - self._targets = self.targets[test_indices] - - def _process_iam_paragraphs(self) -> None: - """Crop the part with the text. - - For each page, crop out the part of it that correspond to the paragraph of text, and make sure all crops are - self.input_shape. The ground truth data is the same size, with a one-hot vector at each pixel - corresponding to labels 0=background, 1=odd-numbered line, 2=even-numbered line - """ - crop_dims = self._decide_on_crop_dims() - CROPS_DIRNAME.mkdir(parents=True, exist_ok=True) - DEBUG_CROPS_DIRNAME.mkdir(parents=True, exist_ok=True) - GT_DIRNAME.mkdir(parents=True, exist_ok=True) - logger.info( - f"Cropping paragraphs, generating ground truth, and saving debugging images to {DEBUG_CROPS_DIRNAME}" - ) - for filename in self.iam_dataset.form_filenames: - id_ = filename.stem - line_region = self.iam_dataset.line_regions_by_id[id_] - _crop_paragraph_image(filename, line_region, crop_dims, self.input_shape) - - def _decide_on_crop_dims(self) -> Tuple[int, int]: - """Decide on the dimensions to crop out of the form image. - - Since image width is larger than a comfortable crop around the longest paragraph, - we will make the crop a square form factor. - And since the found dimensions 610x610 are pretty close to 512x512, - we might as well resize crops and make it exactly that, which lets us - do all kinds of power-of-2 pooling and upsampling should we choose to. - - Returns: - Tuple[int, int]: A tuple of crop dimensions. - - Raises: - RuntimeError: When max crop height is larger than max crop width. - - """ - - sample_form_filename = self.iam_dataset.form_filenames[0] - sample_image = util.read_image(sample_form_filename, grayscale=True) - max_crop_width = sample_image.shape[1] - max_crop_height = _get_max_paragraph_crop_height( - self.iam_dataset.line_regions_by_id - ) - if not max_crop_height <= max_crop_width: - raise RuntimeError( - f"Max crop height is larger then max crop width: {max_crop_height} >= {max_crop_width}" - ) - - crop_dims = (max_crop_width, max_crop_width) - logger.info( - f"Max crop width and height were found to be {max_crop_width}x{max_crop_height}." - ) - logger.info(f"Setting them to {max_crop_width}x{max_crop_width}") - return crop_dims - - def __repr__(self) -> str: - """Return info about the dataset.""" - return ( - "IAM Paragraph Dataset\n" # pylint: disable=no-member - f"Num classes: {self.num_classes}\n" - f"Data: {self.data.shape}\n" - f"Targets: {self.targets.shape}\n" - ) - - -def _get_max_paragraph_crop_height(line_regions_by_id: Dict) -> int: - heights = [] - for regions in line_regions_by_id.values(): - min_y1 = min(region["y1"] for region in regions) - PARAGRAPH_BUFFER - max_y2 = max(region["y2"] for region in regions) + PARAGRAPH_BUFFER - height = max_y2 - min_y1 - heights.append(height) - return max(heights) - - -def _crop_paragraph_image( - filename: str, line_regions: Dict, crop_dims: Tuple[int, int], final_dims: Tuple -) -> None: - image = util.read_image(filename, grayscale=True) - - min_y1 = min(region["y1"] for region in line_regions) - PARAGRAPH_BUFFER - max_y2 = max(region["y2"] for region in line_regions) + PARAGRAPH_BUFFER - height = max_y2 - min_y1 - crop_height = crop_dims[0] - buffer = (crop_height - height) // 2 - - # Generate image crop. - image_crop = 255 * np.ones(crop_dims, dtype=np.uint8) - try: - image_crop[buffer : buffer + height] = image[min_y1:max_y2] - except Exception as e: # pylint: disable=broad-except - logger.error(f"Rescued {filename}: {e}") - return - - # Generate ground truth. - gt_image = np.zeros_like(image_crop, dtype=np.uint8) - for index, region in enumerate(line_regions): - gt_image[ - (region["y1"] - min_y1 + buffer) : (region["y2"] - min_y1 + buffer), - region["x1"] : region["x2"], - ] = (index % 2 + 1) - - # Generate image for debugging. - import matplotlib.pyplot as plt - - cmap = plt.get_cmap("Set1") - image_crop_for_debug = np.dstack([image_crop, image_crop, image_crop]) - for index, region in enumerate(line_regions): - color = [255 * _ for _ in cmap(index)[:-1]] - cv2.rectangle( - image_crop_for_debug, - (region["x1"], region["y1"] - min_y1 + buffer), - (region["x2"], region["y2"] - min_y1 + buffer), - color, - 3, - ) - image_crop_for_debug = cv2.resize( - image_crop_for_debug, final_dims, interpolation=cv2.INTER_AREA - ) - util.write_image(image_crop_for_debug, DEBUG_CROPS_DIRNAME / f"{filename.stem}.jpg") - - image_crop = cv2.resize(image_crop, final_dims, interpolation=cv2.INTER_AREA) - util.write_image(image_crop, CROPS_DIRNAME / f"{filename.stem}.jpg") - - gt_image = cv2.resize(gt_image, final_dims, interpolation=cv2.INTER_NEAREST) - util.write_image(gt_image, GT_DIRNAME / f"{filename.stem}.png") - - -def _load_iam_paragraphs() -> None: - logger.info("Loading IAM paragraph crops and ground truth from image files...") - images = [] - gt_images = [] - ids = [] - for filename in CROPS_DIRNAME.glob("*.jpg"): - id_ = filename.stem - image = util.read_image(filename, grayscale=True) - image = 1.0 - image / 255 - - gt_filename = GT_DIRNAME / f"{id_}.png" - gt_image = util.read_image(gt_filename, grayscale=True) - - images.append(image) - gt_images.append(gt_image) - ids.append(id_) - images = np.array(images).astype(np.float32) - gt_images = np.array(gt_images).astype(np.uint8) - ids = np.array(ids) - return images, gt_images, ids - - -@click.command() -@click.option( - "--subsample_fraction", - type=float, - default=None, - help="The subsampling factor of the dataset.", -) -def main(subsample_fraction: float) -> None: - """Load dataset and print info.""" - logger.info("Creating train set...") - dataset = IamParagraphsDataset(train=True, subsample_fraction=subsample_fraction) - dataset.load_or_generate_data() - print(dataset) - logger.info("Creating test set...") - dataset = IamParagraphsDataset(subsample_fraction=subsample_fraction) - dataset.load_or_generate_data() - print(dataset) - - -if __name__ == "__main__": - main() diff --git a/text_recognizer/datasets/iam_preprocessor.py b/text_recognizer/datasets/iam_preprocessor.py deleted file mode 100644 index a93eb00..0000000 --- a/text_recognizer/datasets/iam_preprocessor.py +++ /dev/null @@ -1,196 +0,0 @@ -"""Preprocessor for extracting word letters from the IAM dataset. - -The code is mostly stolen from: - - https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py - -""" - -import collections -import itertools -from pathlib import Path -import re -from typing import List, Optional, Union - -import click -from loguru import logger -import torch - - -def load_metadata( - data_dir: Path, wordsep: str, use_words: bool = False -) -> collections.defaultdict: - """Loads IAM metadata and returns it as a dictionary.""" - forms = collections.defaultdict(list) - filename = "words.txt" if use_words else "lines.txt" - - with open(data_dir / "ascii" / filename, "r") as f: - lines = (line.strip().split() for line in f if line[0] != "#") - for line in lines: - # Skip word segmentation errors. - if use_words and line[1] == "err": - continue - text = " ".join(line[8:]) - - # Remove garbage tokens: - text = text.replace("#", "") - - # Swap word sep form | to wordsep - text = re.sub(r"\|+|\s", wordsep, text).strip(wordsep) - form_key = "-".join(line[0].split("-")[:2]) - line_key = "-".join(line[0].split("-")[:3]) - box_idx = 4 - use_words - box = tuple(int(val) for val in line[box_idx : box_idx + 4]) - forms[form_key].append({"key": line_key, "box": box, "text": text}) - return forms - - -class Preprocessor: - """A preprocessor for the IAM dataset.""" - - # TODO: add lower case only to when generating... - - def __init__( - self, - data_dir: Union[str, Path], - num_features: int, - tokens_path: Optional[Union[str, Path]] = None, - lexicon_path: Optional[Union[str, Path]] = None, - use_words: bool = False, - prepend_wordsep: bool = False, - ) -> None: - self.wordsep = "▁" - self._use_word = use_words - self._prepend_wordsep = prepend_wordsep - - self.data_dir = Path(data_dir) - - self.forms = load_metadata(self.data_dir, self.wordsep, use_words=use_words) - - # Load the set of graphemes: - graphemes = set() - for _, form in self.forms.items(): - for line in form: - graphemes.update(line["text"].lower()) - self.graphemes = sorted(graphemes) - - # Build the token-to-index and index-to-token maps. - if tokens_path is not None: - with open(tokens_path, "r") as f: - self.tokens = [line.strip() for line in f] - else: - self.tokens = self.graphemes - - if lexicon_path is not None: - with open(lexicon_path, "r") as f: - lexicon = (line.strip().split() for line in f) - lexicon = {line[0]: line[1:] for line in lexicon} - self.lexicon = lexicon - else: - self.lexicon = None - - self.graphemes_to_index = {t: i for i, t in enumerate(self.graphemes)} - self.tokens_to_index = {t: i for i, t in enumerate(self.tokens)} - self.num_features = num_features - self.text = [] - - @property - def num_tokens(self) -> int: - """Returns the number or tokens.""" - return len(self.tokens) - - @property - def use_words(self) -> bool: - """If words are used.""" - return self._use_word - - def extract_train_text(self) -> None: - """Extracts training text.""" - keys = [] - with open(self.data_dir / "task" / "trainset.txt") as f: - keys.extend((line.strip() for line in f)) - - for _, examples in self.forms.items(): - for example in examples: - if example["key"] not in keys: - continue - self.text.append(example["text"].lower()) - - def to_index(self, line: str) -> torch.LongTensor: - """Converts text to a tensor of indices.""" - token_to_index = self.graphemes_to_index - if self.lexicon is not None: - if len(line) > 0: - # If the word is not found in the lexicon, fall back to letters. - line = [ - t - for w in line.split(self.wordsep) - for t in self.lexicon.get(w, self.wordsep + w) - ] - token_to_index = self.tokens_to_index - if self._prepend_wordsep: - line = itertools.chain([self.wordsep], line) - return torch.LongTensor([token_to_index[t] for t in line]) - - def to_text(self, indices: List[int]) -> str: - """Converts indices to text.""" - # Roughly the inverse of `to_index` - encoding = self.graphemes - if self.lexicon is not None: - encoding = self.tokens - return self._post_process(encoding[i] for i in indices) - - def tokens_to_text(self, indices: List[int]) -> str: - """Converts tokens to text.""" - return self._post_process(self.tokens[i] for i in indices) - - def _post_process(self, indices: List[int]) -> str: - """A list join.""" - return "".join(indices).strip(self.wordsep) - - -@click.command() -@click.option("--data_dir", type=str, default=None, help="Path to iam dataset") -@click.option( - "--use_words", is_flag=True, help="Load word segmented dataset instead of lines" -) -@click.option( - "--save_text", type=str, default=None, help="Path to save parsed train text" -) -@click.option("--save_tokens", type=str, default=None, help="Path to save tokens") -def cli( - data_dir: Optional[str], - use_words: bool, - save_text: Optional[str], - save_tokens: Optional[str], -) -> None: - """CLI for extracting text data from the iam dataset.""" - if data_dir is None: - data_dir = ( - Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb" - ) - logger.debug(f"Using data dir: {data_dir}") - if not data_dir.exists(): - raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") - else: - data_dir = Path(data_dir) - - preprocessor = Preprocessor(data_dir, 64, use_words=use_words) - preprocessor.extract_train_text() - - processed_dir = data_dir.parents[2] / "processed" / "iam_lines" - logger.debug(f"Saving processed files at: {processed_dir}") - - if save_text is not None: - logger.info("Saving training text") - with open(processed_dir / save_text, "w") as f: - f.write("\n".join(t for t in preprocessor.text)) - - if save_tokens is not None: - logger.info("Saving tokens") - with open(processed_dir / save_tokens, "w") as f: - f.write("\n".join(preprocessor.tokens)) - - -if __name__ == "__main__": - cli() diff --git a/text_recognizer/datasets/sentence_generator.py b/text_recognizer/datasets/sentence_generator.py deleted file mode 100644 index 53b781c..0000000 --- a/text_recognizer/datasets/sentence_generator.py +++ /dev/null @@ -1,85 +0,0 @@ -"""Downloading the Brown corpus with NLTK for sentence generating.""" - -import itertools -import re -import string -from typing import Optional - -import nltk -from nltk.corpus.reader.util import ConcatenatedCorpusView -import numpy as np - -from text_recognizer.datasets.util import DATA_DIRNAME - -NLTK_DATA_DIRNAME = DATA_DIRNAME / "downloaded" / "nltk" - - -class SentenceGenerator: - """Generates text sentences using the Brown corpus.""" - - def __init__(self, max_length: Optional[int] = None) -> None: - """Loads the corpus and sets word start indices.""" - self.corpus = brown_corpus() - self.word_start_indices = [0] + [ - _.start(0) + 1 for _ in re.finditer(" ", self.corpus) - ] - self.max_length = max_length - - def generate(self, max_length: Optional[int] = None) -> str: - """Generates a word or sentences from the Brown corpus. - - Sample a string from the Brown corpus of length at least one word and at most max_length, padding to - max_length with the '_' characters if sentence is shorter. - - Args: - max_length (Optional[int]): The maximum number of characters in the sentence. Defaults to None. - - Returns: - str: A sentence from the Brown corpus. - - Raises: - ValueError: If max_length was not specified at initialization and not given as an argument. - - """ - if max_length is None: - max_length = self.max_length - if max_length is None: - raise ValueError( - "Must provide max_length to this method or when making this object." - ) - - for _ in range(10): - try: - index = np.random.randint(0, len(self.word_start_indices) - 1) - start_index = self.word_start_indices[index] - end_index_candidates = [] - for index in range(index + 1, len(self.word_start_indices)): - if self.word_start_indices[index] - start_index > max_length: - break - end_index_candidates.append(self.word_start_indices[index]) - end_index = np.random.choice(end_index_candidates) - sampled_text = self.corpus[start_index:end_index].strip() - return sampled_text - except Exception: - pass - raise RuntimeError("Was not able to generate a valid string") - - -def brown_corpus() -> str: - """Returns a single string with the Brown corpus with all punctuations stripped.""" - sentences = load_nltk_brown_corpus() - corpus = " ".join(itertools.chain.from_iterable(sentences)) - corpus = corpus.translate({ord(c): None for c in string.punctuation}) - corpus = re.sub(" +", " ", corpus) - return corpus - - -def load_nltk_brown_corpus() -> ConcatenatedCorpusView: - """Load the Brown corpus using the NLTK library.""" - nltk.data.path.append(NLTK_DATA_DIRNAME) - try: - nltk.corpus.brown.sents() - except LookupError: - NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) - nltk.download("brown", download_dir=NLTK_DATA_DIRNAME) - return nltk.corpus.brown.sents() diff --git a/text_recognizer/datasets/transforms.py b/text_recognizer/datasets/transforms.py deleted file mode 100644 index b6a48f5..0000000 --- a/text_recognizer/datasets/transforms.py +++ /dev/null @@ -1,266 +0,0 @@ -"""Transforms for PyTorch datasets.""" -from abc import abstractmethod -from pathlib import Path -import random -from typing import Any, Optional, Union - -from loguru import logger -import numpy as np -from PIL import Image -import torch -from torch import Tensor -import torch.nn.functional as F -from torchvision import transforms -from torchvision.transforms import ( - ColorJitter, - Compose, - Normalize, - RandomAffine, - RandomHorizontalFlip, - RandomRotation, - ToPILImage, - ToTensor, -) - -from text_recognizer.datasets.iam_preprocessor import Preprocessor -from text_recognizer.datasets.util import EmnistMapper - - -class RandomResizeCrop: - """Image transform with random resize and crop applied. - - Stolen from - - https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py - - """ - - def __init__(self, jitter: int = 10, ratio: float = 0.5) -> None: - self.jitter = jitter - self.ratio = ratio - - def __call__(self, img: np.ndarray) -> np.ndarray: - """Applies random crop and rotation to an image.""" - w, h = img.size - - # pad with white: - img = transforms.functional.pad(img, self.jitter, fill=255) - - # crop at random (x, y): - x = self.jitter + random.randint(-self.jitter, self.jitter) - y = self.jitter + random.randint(-self.jitter, self.jitter) - - # randomize aspect ratio: - size_w = w * random.uniform(1 - self.ratio, 1 + self.ratio) - size = (h, int(size_w)) - img = transforms.functional.resized_crop(img, y, x, h, w, size) - return img - - -class Transpose: - """Transposes the EMNIST image to the correct orientation.""" - - def __call__(self, image: Image) -> np.ndarray: - """Swaps axis.""" - return np.array(image).swapaxes(0, 1) - - -class Resize: - """Resizes a tensor to a specified width.""" - - def __init__(self, width: int = 952) -> None: - # The default is 952 because of the IAM dataset. - self.width = width - - def __call__(self, image: Tensor) -> Tensor: - """Resize tensor in the last dimension.""" - return F.interpolate(image, size=self.width, mode="nearest") - - -class AddTokens: - """Adds start of sequence and end of sequence tokens to target tensor.""" - - def __init__(self, pad_token: str, eos_token: str, init_token: str = None) -> None: - self.init_token = init_token - self.pad_token = pad_token - self.eos_token = eos_token - if self.init_token is not None: - self.emnist_mapper = EmnistMapper( - init_token=self.init_token, - pad_token=self.pad_token, - eos_token=self.eos_token, - ) - else: - self.emnist_mapper = EmnistMapper( - pad_token=self.pad_token, eos_token=self.eos_token, - ) - self.pad_value = self.emnist_mapper(self.pad_token) - self.eos_value = self.emnist_mapper(self.eos_token) - - def __call__(self, target: Tensor) -> Tensor: - """Adds a sos token to the begining and a eos token to the end of a target sequence.""" - dtype, device = target.dtype, target.device - - # Find the where padding starts. - pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item() - - target[pad_index] = self.eos_value - - if self.init_token is not None: - self.sos_value = self.emnist_mapper(self.init_token) - sos = torch.tensor([self.sos_value], dtype=dtype, device=device) - target = torch.cat([sos, target], dim=0) - - return target - - -class ApplyContrast: - """Sets everything below a threshold to zero, i.e. increase contrast.""" - - def __init__(self, low: float = 0.0, high: float = 0.25) -> None: - self.low = low - self.high = high - - def __call__(self, x: Tensor) -> Tensor: - """Apply mask binary mask to input tensor.""" - mask = x > np.random.RandomState().uniform(low=self.low, high=self.high) - return x * mask - - -class Unsqueeze: - """Add a dimension to the tensor.""" - - def __call__(self, x: Tensor) -> Tensor: - """Adds dim.""" - return x.unsqueeze(0) - - -class Squeeze: - """Removes the first dimension of a tensor.""" - - def __call__(self, x: Tensor) -> Tensor: - """Removes first dim.""" - return x.squeeze(0) - - -class ToLower: - """Converts target to lower case.""" - - def __call__(self, target: Tensor) -> Tensor: - """Corrects index value in target tensor.""" - device = target.device - return torch.stack([x - 26 if x > 35 else x for x in target]).to(device) - - -class ToCharcters: - """Converts integers to characters.""" - - def __init__( - self, pad_token: str, eos_token: str, init_token: str = None, lower: bool = True - ) -> None: - self.init_token = init_token - self.pad_token = pad_token - self.eos_token = eos_token - if self.init_token is not None: - self.emnist_mapper = EmnistMapper( - init_token=self.init_token, - pad_token=self.pad_token, - eos_token=self.eos_token, - lower=lower, - ) - else: - self.emnist_mapper = EmnistMapper( - pad_token=self.pad_token, eos_token=self.eos_token, lower=lower - ) - - def __call__(self, y: Tensor) -> str: - """Converts a Tensor to a str.""" - return ( - "".join([self.emnist_mapper(int(i)) for i in y]) - .strip("_") - .replace(" ", "▁") - ) - - -class WordPieces: - """Abstract transform for word pieces.""" - - def __init__( - self, - num_features: int, - data_dir: Optional[Union[str, Path]] = None, - tokens: Optional[Union[str, Path]] = None, - lexicon: Optional[Union[str, Path]] = None, - use_words: bool = False, - prepend_wordsep: bool = False, - ) -> None: - if data_dir is None: - data_dir = ( - Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb" - ) - logger.debug(f"Using data dir: {data_dir}") - if not data_dir.exists(): - raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") - else: - data_dir = Path(data_dir) - processed_path = ( - Path(__file__).resolve().parents[3] / "data" / "processed" / "iam_lines" - ) - tokens_path = processed_path / tokens - lexicon_path = processed_path / lexicon - - self.preprocessor = Preprocessor( - data_dir, - num_features, - tokens_path, - lexicon_path, - use_words, - prepend_wordsep, - ) - - @abstractmethod - def __call__(self, *args, **kwargs) -> Any: - """Transforms input.""" - ... - - -class ToWordPieces(WordPieces): - """Transforms str to word pieces.""" - - def __init__( - self, - num_features: int, - data_dir: Optional[Union[str, Path]] = None, - tokens: Optional[Union[str, Path]] = None, - lexicon: Optional[Union[str, Path]] = None, - use_words: bool = False, - prepend_wordsep: bool = False, - ) -> None: - super().__init__( - num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep - ) - - def __call__(self, line: str) -> Tensor: - """Transforms str to word pieces.""" - return self.preprocessor.to_index(line) - - -class ToText(WordPieces): - """Takes word pieces and converts them to text.""" - - def __init__( - self, - num_features: int, - data_dir: Optional[Union[str, Path]] = None, - tokens: Optional[Union[str, Path]] = None, - lexicon: Optional[Union[str, Path]] = None, - use_words: bool = False, - prepend_wordsep: bool = False, - ) -> None: - super().__init__( - num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep - ) - - def __call__(self, x: Tensor) -> str: - """Converts tensor to text.""" - return self.preprocessor.to_text(x.tolist()) diff --git a/text_recognizer/datasets/util.py b/text_recognizer/datasets/util.py deleted file mode 100644 index da87756..0000000 --- a/text_recognizer/datasets/util.py +++ /dev/null @@ -1,209 +0,0 @@ -"""Util functions for datasets.""" -import hashlib -import json -import os -from pathlib import Path -import string -from typing import Dict, List, Optional, Union -from urllib.request import urlretrieve - -from loguru import logger -import numpy as np -import torch -from torch import Tensor -from torchvision.datasets import EMNIST -from tqdm import tqdm - -DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data" -ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json" - - -def save_emnist_essentials(emnsit_dataset: EMNIST = EMNIST) -> None: - """Extract and saves EMNIST essentials.""" - labels = emnsit_dataset.classes - labels.sort() - mapping = [(i, str(label)) for i, label in enumerate(labels)] - essentials = { - "mapping": mapping, - "input_shape": tuple(np.array(emnsit_dataset[0][0]).shape[:]), - } - logger.info("Saving emnist essentials...") - with open(ESSENTIALS_FILENAME, "w") as f: - json.dump(essentials, f) - - -def download_emnist() -> None: - """Download the EMNIST dataset via the PyTorch class.""" - logger.info(f"Data directory is: {DATA_DIRNAME}") - dataset = EMNIST(root=DATA_DIRNAME, split="byclass", download=True) - save_emnist_essentials(dataset) - - -class EmnistMapper: - """Mapper between network output to Emnist character.""" - - def __init__( - self, - pad_token: str, - init_token: Optional[str] = None, - eos_token: Optional[str] = None, - lower: bool = False, - ) -> None: - """Loads the emnist essentials file with the mapping and input shape.""" - self.init_token = init_token - self.pad_token = pad_token - self.eos_token = eos_token - self.lower = lower - - self.essentials = self._load_emnist_essentials() - # Load dataset information. - self._mapping = dict(self.essentials["mapping"]) - self._augment_emnist_mapping() - self._inverse_mapping = {v: k for k, v in self.mapping.items()} - self._num_classes = len(self.mapping) - self._input_shape = self.essentials["input_shape"] - - def __call__(self, token: Union[str, int, np.uint8, Tensor]) -> Union[str, int]: - """Maps the token to emnist character or character index. - - If the token is an integer (index), the method will return the Emnist character corresponding to that index. - If the token is a str (Emnist character), the method will return the corresponding index for that character. - - Args: - token (Union[str, int, np.uint8, Tensor]): Either a string or index (integer). - - Returns: - Union[str, int]: The mapping result. - - Raises: - KeyError: If the index or string does not exist in the mapping. - - """ - if ( - (isinstance(token, np.uint8) or isinstance(token, int)) - or torch.is_tensor(token) - and int(token) in self.mapping - ): - return self.mapping[int(token)] - elif isinstance(token, str) and token in self._inverse_mapping: - return self._inverse_mapping[token] - else: - raise KeyError(f"Token {token} does not exist in the mappings.") - - @property - def mapping(self) -> Dict: - """Returns the mapping between index and character.""" - return self._mapping - - @property - def inverse_mapping(self) -> Dict: - """Returns the mapping between character and index.""" - return self._inverse_mapping - - @property - def num_classes(self) -> int: - """Returns the number of classes in the dataset.""" - return self._num_classes - - @property - def input_shape(self) -> List[int]: - """Returns the input shape of the Emnist characters.""" - return self._input_shape - - def _load_emnist_essentials(self) -> Dict: - """Load the EMNIST mapping.""" - with open(str(ESSENTIALS_FILENAME)) as f: - essentials = json.load(f) - return essentials - - def _augment_emnist_mapping(self) -> None: - """Augment the mapping with extra symbols.""" - # Extra symbols in IAM dataset - if self.lower: - self._mapping = { - k: str(v) - for k, v in enumerate(list(range(10)) + list(string.ascii_lowercase)) - } - - extra_symbols = [ - " ", - "!", - '"', - "#", - "&", - "'", - "(", - ")", - "*", - "+", - ",", - "-", - ".", - "/", - ":", - ";", - "?", - ] - - # padding symbol, and acts as blank symbol as well. - extra_symbols.append(self.pad_token) - - if self.init_token is not None: - extra_symbols.append(self.init_token) - - if self.eos_token is not None: - extra_symbols.append(self.eos_token) - - max_key = max(self.mapping.keys()) - extra_mapping = {} - for i, symbol in enumerate(extra_symbols): - extra_mapping[max_key + 1 + i] = symbol - - self._mapping = {**self.mapping, **extra_mapping} - - -def compute_sha256(filename: Union[Path, str]) -> str: - """Returns the SHA256 checksum of a file.""" - with open(filename, "rb") as f: - return hashlib.sha256(f.read()).hexdigest() - - -class TqdmUpTo(tqdm): - """TQDM progress bar when downloading files. - - From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py - - """ - - def update_to( - self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None - ) -> None: - """Updates the progress bar. - - Args: - blocks (int): Number of blocks transferred so far. Defaults to 1. - block_size (int): Size of each block, in tqdm units. Defaults to 1. - total_size (Optional[int]): Total size in tqdm units. Defaults to None. - """ - if total_size is not None: - self.total = total_size # pylint: disable=attribute-defined-outside-init - self.update(blocks * block_size - self.n) - - -def download_url(url: str, filename: str) -> None: - """Downloads a file from url to filename, with a progress bar.""" - with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: - urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec - - -def _download_raw_dataset(metadata: Dict) -> None: - if os.path.exists(metadata["filename"]): - return - logger.info(f"Downloading raw dataset from {metadata['url']}...") - download_url(metadata["url"], metadata["filename"]) - logger.info("Computing SHA-256...") - sha256 = compute_sha256(metadata["filename"]) - if sha256 != metadata["sha256"]: - raise ValueError( - "Downloaded data file SHA-256 does not match that listed in metadata document." - ) |