summaryrefslogtreecommitdiff
path: root/text_recognizer/datasets
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-24 22:15:54 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-24 22:15:54 +0100
commit8248f173132dfb7e47ec62b08e9235990c8626e3 (patch)
tree2f3ff85602cbc08b7168bf4f0d3924d32a689852 /text_recognizer/datasets
parent74c907a17379688967dc4b3f41a44ba83034f5e0 (diff)
renamed datasets to data, added iam refactor
Diffstat (limited to 'text_recognizer/datasets')
-rw-r--r--text_recognizer/datasets/__init__.py1
-rw-r--r--text_recognizer/datasets/base_data_module.py89
-rw-r--r--text_recognizer/datasets/base_dataset.py73
-rw-r--r--text_recognizer/datasets/download_utils.py73
-rw-r--r--text_recognizer/datasets/emnist.py210
-rw-r--r--text_recognizer/datasets/emnist_essentials.json1
-rw-r--r--text_recognizer/datasets/emnist_lines.py280
-rw-r--r--text_recognizer/datasets/iam_dataset.py133
-rw-r--r--text_recognizer/datasets/iam_lines_dataset.py110
-rw-r--r--text_recognizer/datasets/iam_paragraphs_dataset.py291
-rw-r--r--text_recognizer/datasets/iam_preprocessor.py196
-rw-r--r--text_recognizer/datasets/sentence_generator.py85
-rw-r--r--text_recognizer/datasets/transforms.py266
-rw-r--r--text_recognizer/datasets/util.py209
14 files changed, 0 insertions, 2017 deletions
diff --git a/text_recognizer/datasets/__init__.py b/text_recognizer/datasets/__init__.py
deleted file mode 100644
index 2727b20..0000000
--- a/text_recognizer/datasets/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""Dataset modules."""
diff --git a/text_recognizer/datasets/base_data_module.py b/text_recognizer/datasets/base_data_module.py
deleted file mode 100644
index f5e7300..0000000
--- a/text_recognizer/datasets/base_data_module.py
+++ /dev/null
@@ -1,89 +0,0 @@
-"""Base lightning DataModule class."""
-from pathlib import Path
-from typing import Dict
-
-import pytorch_lightning as pl
-from torch.utils.data import DataLoader
-
-
-def load_and_print_info(data_module_class: type) -> None:
- """Load EMNISTLines and prints info."""
- dataset = data_module_class()
- dataset.prepare_data()
- dataset.setup()
- print(dataset)
-
-
-class BaseDataModule(pl.LightningDataModule):
- """Base PyTorch Lightning DataModule."""
-
- def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None:
- super().__init__()
- self.batch_size = batch_size
- self.num_workers = num_workers
-
- # Placeholders for subclasses.
- self.dims = None
- self.output_dims = None
- self.mapping = None
-
- @classmethod
- def data_dirname(cls) -> Path:
- """Return the path to the base data directory."""
- return Path(__file__).resolve().parents[2] / "data"
-
- def config(self) -> Dict:
- """Return important settings of the dataset."""
- return {
- "input_dim": self.dims,
- "output_dims": self.output_dims,
- "mapping": self.mapping,
- }
-
- def prepare_data(self) -> None:
- """Prepare data for training."""
- pass
-
- def setup(self, stage: str = None) -> None:
- """Split into train, val, test, and set dims.
-
- Should assign `torch Dataset` objects to self.data_train, self.data_val, and
- optionally self.data_test.
-
- Args:
- stage (Any): Variable to set splits.
-
- """
- self.data_train = None
- self.data_val = None
- self.data_test = None
-
- def train_dataloader(self) -> DataLoader:
- """Retun DataLoader for train data."""
- return DataLoader(
- self.data_train,
- shuffle=True,
- batch_size=self.batch_size,
- num_workers=self.num_workers,
- pin_memory=True,
- )
-
- def val_dataloader(self) -> DataLoader:
- """Return DataLoader for val data."""
- return DataLoader(
- self.data_val,
- shuffle=False,
- batch_size=self.batch_size,
- num_workers=self.num_workers,
- pin_memory=True,
- )
-
- def test_dataloader(self) -> DataLoader:
- """Return DataLoader for val data."""
- return DataLoader(
- self.data_test,
- shuffle=False,
- batch_size=self.batch_size,
- num_workers=self.num_workers,
- pin_memory=True,
- )
diff --git a/text_recognizer/datasets/base_dataset.py b/text_recognizer/datasets/base_dataset.py
deleted file mode 100644
index a9e9c24..0000000
--- a/text_recognizer/datasets/base_dataset.py
+++ /dev/null
@@ -1,73 +0,0 @@
-"""Base PyTorch Dataset class."""
-from typing import Any, Callable, Dict, Sequence, Tuple, Union
-
-import torch
-from torch import Tensor
-from torch.utils.data import Dataset
-
-
-class BaseDataset(Dataset):
- """
- Base Dataset class that processes data and targets through optional transfroms.
-
- Args:
- data (Union[Sequence, Tensor]): Torch tensors, numpy arrays, or PIL images.
- targets (Union[Sequence, Tensor]): Torch tensors or numpy arrays.
- tranform (Callable): Function that takes a datum and applies transforms.
- target_transform (Callable): Fucntion that takes a target and applies
- target transforms.
- """
-
- def __init__(
- self,
- data: Union[Sequence, Tensor],
- targets: Union[Sequence, Tensor],
- transform: Callable = None,
- target_transform: Callable = None,
- ) -> None:
- if len(data) != len(targets):
- raise ValueError("Data and targets must be of equal length.")
- self.data = data
- self.targets = targets
- self.transform = transform
- self.target_transform = target_transform
-
- def __len__(self) -> int:
- """Return the length of the dataset."""
- return len(self.data)
-
- def __getitem__(self, index: int) -> Tuple[Any, Any]:
- """Return a datum and its target, after processing by transforms.
-
- Args:
- index (int): Index of a datum in the dataset.
-
- Returns:
- Tuple[Any, Any]: Datum and target pair.
-
- """
- datum, target = self.data[index], self.targets[index]
-
- if self.transform is not None:
- datum = self.transform(datum)
-
- if self.target_transform is not None:
- target = self.target_transform(target)
-
- return datum, target
-
-
-def convert_strings_to_labels(
- strings: Sequence[str], mapping: Dict[str, int], length: int
-) -> Tensor:
- """
- Convert a sequence of N strings to (N, length) ndarray, with each string wrapped with <s> and </s> tokens,
- and padded wiht the <p> token.
- """
- labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["<p>"]
- for i, string in enumerate(strings):
- tokens = list(string)
- tokens = ["<s>", *tokens, "</s>"]
- for j, token in enumerate(tokens):
- labels[i, j] = mapping[token]
- return labels
diff --git a/text_recognizer/datasets/download_utils.py b/text_recognizer/datasets/download_utils.py
deleted file mode 100644
index e3dc68c..0000000
--- a/text_recognizer/datasets/download_utils.py
+++ /dev/null
@@ -1,73 +0,0 @@
-"""Util functions for downloading datasets."""
-import hashlib
-from pathlib import Path
-from typing import Dict, List, Optional
-from urllib.request import urlretrieve
-
-from loguru import logger
-from tqdm import tqdm
-
-
-def _compute_sha256(filename: Path) -> str:
- """Returns the SHA256 checksum of a file."""
- with filename.open(mode="rb") as f:
- return hashlib.sha256(f.read()).hexdigest()
-
-
-class TqdmUpTo(tqdm):
- """TQDM progress bar when downloading files.
-
- From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
-
- """
-
- def update_to(
- self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None
- ) -> None:
- """Updates the progress bar.
-
- Args:
- blocks (int): Number of blocks transferred so far. Defaults to 1.
- block_size (int): Size of each block, in tqdm units. Defaults to 1.
- total_size (Optional[int]): Total size in tqdm units. Defaults to None.
- """
- if total_size is not None:
- self.total = total_size # pylint: disable=attribute-defined-outside-init
- self.update(blocks * block_size - self.n)
-
-
-def _download_url(url: str, filename: str) -> None:
- """Downloads a file from url to filename, with a progress bar."""
- with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
- urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec
-
-
-def download_dataset(metadata: Dict, dl_dir: Path) -> Optional[Path]:
- """Downloads dataset using a metadata file.
-
- Args:
- metadata (Dict): A metadata file of the dataset.
- dl_dir (Path): Download directory for the dataset.
-
- Returns:
- Optional[Path]: Returns filename if dataset is downloaded, None if it already
- exists.
-
- Raises:
- ValueError: If the SHA-256 value is not the same between the dataset and
- the metadata file.
-
- """
- dl_dir.mkdir(parents=True, exist_ok=True)
- filename = dl_dir / metadata["filename"]
- if filename.exists():
- return
- logger.info(f"Downloading raw dataset from {metadata['url']} to {filename}...")
- _download_url(metadata["url"], filename)
- logger.info("Computing the SHA-256...")
- sha256 = _compute_sha256(filename)
- if sha256 != metadata["sha256"]:
- raise ValueError(
- "Downloaded data file SHA-256 does not match that listed in metadata document."
- )
- return filename
diff --git a/text_recognizer/datasets/emnist.py b/text_recognizer/datasets/emnist.py
deleted file mode 100644
index 66101b5..0000000
--- a/text_recognizer/datasets/emnist.py
+++ /dev/null
@@ -1,210 +0,0 @@
-"""EMNIST dataset: downloads it from FSDL aws url if not present."""
-from pathlib import Path
-from typing import Sequence, Tuple
-import json
-import os
-import shutil
-import zipfile
-
-import h5py
-import numpy as np
-from loguru import logger
-import toml
-import torch
-from torch.utils.data import random_split
-from torchvision import transforms
-
-from text_recognizer.datasets.base_dataset import BaseDataset
-from text_recognizer.datasets.base_data_module import (
- BaseDataModule,
- load_and_print_info,
-)
-from text_recognizer.datasets.download_utils import download_dataset
-
-
-SEED = 4711
-NUM_SPECIAL_TOKENS = 4
-SAMPLE_TO_BALANCE = True
-
-RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "emnist"
-METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
-DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "emnist"
-PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist"
-PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5"
-ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json"
-
-
-class EMNIST(BaseDataModule):
- """
- "The EMNIST dataset is a set of handwritten character digits derived from the NIST Special Database 19
- and converted to a 28x28 pixel image format and dataset structure that directly matches the MNIST dataset."
- From https://www.nist.gov/itl/iad/image-group/emnist-dataset
-
- The data split we will use is
- EMNIST ByClass: 814,255 characters. 62 unbalanced classes.
- """
-
- def __init__(
- self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8
- ) -> None:
- super().__init__(batch_size, num_workers)
- if not ESSENTIALS_FILENAME.exists():
- _download_and_process_emnist()
- with ESSENTIALS_FILENAME.open() as f:
- essentials = json.load(f)
- self.train_fraction = train_fraction
- self.mapping = list(essentials["characters"])
- self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)}
- self.data_train = None
- self.data_val = None
- self.data_test = None
- self.transform = transforms.Compose([transforms.ToTensor()])
- self.dims = (1, *essentials["input_shape"])
- self.output_dims = (1,)
-
- def prepare_data(self) -> None:
- if not PROCESSED_DATA_FILENAME.exists():
- _download_and_process_emnist()
-
- def setup(self, stage: str = None) -> None:
- if stage == "fit" or stage is None:
- with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
- self.x_train = f["x_train"][:]
- self.y_train = f["y_train"][:].squeeze().astype(int)
-
- dataset_train = BaseDataset(
- self.x_train, self.y_train, transform=self.transform
- )
- train_size = int(self.train_fraction * len(dataset_train))
- val_size = len(dataset_train) - train_size
- self.data_train, self.data_val = random_split(
- dataset_train, [train_size, val_size], generator=torch.Generator()
- )
-
- if stage == "test" or stage is None:
- with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
- self.x_test = f["x_test"][:]
- self.y_test = f["y_test"][:].squeeze().astype(int)
- self.data_test = BaseDataset(
- self.x_test, self.y_test, transform=self.transform
- )
-
- def __repr__(self) -> str:
- basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.dims}\n"
- if not any([self.data_train, self.data_val, self.data_test]):
- return basic
-
- datum, target = next(iter(self.train_dataloader()))
- data = (
- f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n"
- f"Batch x stats: {(datum.shape, datum.dtype, datum.min(), datum.mean(), datum.std(), datum.max())}\n"
- f"Batch y stats: {(target.shape, target.dtype, target.min(), target.max())}\n"
- )
-
- return basic + data
-
-
-def _download_and_process_emnist() -> None:
- metadata = toml.load(METADATA_FILENAME)
- download_dataset(metadata, DL_DATA_DIRNAME)
- _process_raw_dataset(metadata["filename"], DL_DATA_DIRNAME)
-
-
-def _process_raw_dataset(filename: str, dirname: Path) -> None:
- logger.info("Unzipping EMNIST...")
- curdir = os.getcwd()
- os.chdir(dirname)
- content = zipfile.ZipFile(filename, "r")
- content.extract("matlab/emnist-byclass.mat")
-
- from scipy.io import loadmat
-
- logger.info("Loading training data from .mat file")
- data = loadmat("matlab/emnist-byclass.mat")
- x_train = (
- data["dataset"]["train"][0, 0]["images"][0, 0]
- .reshape(-1, 28, 28)
- .swapaxes(1, 2)
- )
- y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
- x_test = (
- data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2)
- )
- y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
-
- if SAMPLE_TO_BALANCE:
- logger.info("Balancing classes to reduce amount of data")
- x_train, y_train = _sample_to_balance(x_train, y_train)
- x_test, y_test = _sample_to_balance(x_test, y_test)
-
- logger.info("Saving to HDF5 in a compressed format...")
- PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
- with h5py.File(PROCESSED_DATA_FILENAME, "w") as f:
- f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf")
- f.create_dataset("y_train", data=y_train, dtype="u1", compression="lzf")
- f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf")
- f.create_dataset("y_test", data=y_test, dtype="u1", compression="lzf")
-
- logger.info("Saving essential dataset parameters to text_recognizer/datasets...")
- mapping = {int(k): chr(v) for k, v in data["dataset"]["mapping"][0, 0]}
- characters = _augment_emnist_characters(mapping.values())
- essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])}
-
- with ESSENTIALS_FILENAME.open(mode="w") as f:
- json.dump(essentials, f)
-
- logger.info("Cleaning up...")
- shutil.rmtree("matlab")
- os.chdir(curdir)
-
-
-def _sample_to_balance(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
- """Balances the dataset by taking the mean number of instances per class."""
- np.random.seed(SEED)
- num_to_sample = int(np.bincount(y.flatten()).mean())
- all_sampled_indices = []
- for label in np.unique(y.flatten()):
- indices = np.where(y == label)[0]
- sampled_indices = np.unique(np.random.choice(indices, num_to_sample))
- all_sampled_indices.append(sampled_indices)
- indices = np.concatenate(all_sampled_indices)
- x_sampled = x[indices]
- y_sampled = y[indices]
- return x_sampled, y_sampled
-
-
-def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]:
- """Augment the mapping with extra symbols."""
- # Extra characters from the IAM dataset.
- iam_characters = [
- " ",
- "!",
- '"',
- "#",
- "&",
- "'",
- "(",
- ")",
- "*",
- "+",
- ",",
- "-",
- ".",
- "/",
- ":",
- ";",
- "?",
- ]
-
- # Also add special tokens for:
- # - CTC blank token at index 0
- # - Start token at index 1
- # - End token at index 2
- # - Padding token at index 3
- # Note: Do not forget to update NUM_SPECIAL_TOKENS if changing this!
- return ["<b>", "<s>", "</s>", "<p>", *characters, *iam_characters]
-
-
-def download_emnist() -> None:
- """Download dataset from internet, if it does not exists, and displays info."""
- load_and_print_info(EMNIST)
diff --git a/text_recognizer/datasets/emnist_essentials.json b/text_recognizer/datasets/emnist_essentials.json
deleted file mode 100644
index 3f46a73..0000000
--- a/text_recognizer/datasets/emnist_essentials.json
+++ /dev/null
@@ -1 +0,0 @@
-{"characters": ["<b>", "<s>", "</s>", "<p>", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]}
diff --git a/text_recognizer/datasets/emnist_lines.py b/text_recognizer/datasets/emnist_lines.py
deleted file mode 100644
index 9ebad22..0000000
--- a/text_recognizer/datasets/emnist_lines.py
+++ /dev/null
@@ -1,280 +0,0 @@
-"""Dataset of generated text from EMNIST characters."""
-from collections import defaultdict
-from pathlib import Path
-from typing import Callable, Dict, Tuple, Sequence
-
-import h5py
-from loguru import logger
-import numpy as np
-from PIL import Image
-import torch
-from torchvision import transforms
-from torchvision.transforms.functional import InterpolationMode
-
-from text_recognizer.datasets.base_dataset import BaseDataset, convert_strings_to_labels
-from text_recognizer.datasets.base_data_module import (
- BaseDataModule,
- load_and_print_info,
-)
-from text_recognizer.datasets.emnist import EMNIST
-from text_recognizer.datasets.sentence_generator import SentenceGenerator
-
-
-DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist_lines"
-ESSENTIALS_FILENAME = (
- Path(__file__).parents[0].resolve() / "emnist_lines_essentials.json"
-)
-
-SEED = 4711
-IMAGE_HEIGHT = 56
-IMAGE_WIDTH = 1024
-IMAGE_X_PADDING = 28
-MAX_OUTPUT_LENGTH = 89 # Same as IAMLines
-
-
-class EMNISTLines(BaseDataModule):
- """EMNIST Lines dataset: synthetic handwritten lines dataset made from EMNIST,"""
-
- def __init__(
- self,
- augment: bool = True,
- batch_size: int = 128,
- num_workers: int = 0,
- max_length: int = 32,
- min_overlap: float = 0.0,
- max_overlap: float = 0.33,
- num_train: int = 10_000,
- num_val: int = 2_000,
- num_test: int = 2_000,
- ) -> None:
- super().__init__(batch_size, num_workers)
-
- self.augment = augment
- self.max_length = max_length
- self.min_overlap = min_overlap
- self.max_overlap = max_overlap
- self.num_train = num_train
- self.num_val = num_val
- self.num_test = num_test
-
- self.emnist = EMNIST()
- self.mapping = self.emnist.mapping
- max_width = (
- int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap))
- + IMAGE_X_PADDING
- )
-
- if max_width >= IMAGE_WIDTH:
- raise ValueError(
- f"max_width {max_width} greater than IMAGE_WIDTH {IMAGE_WIDTH}"
- )
-
- self.dims = (
- self.emnist.dims[0],
- IMAGE_HEIGHT,
- IMAGE_WIDTH
- )
-
- if self.max_length >= MAX_OUTPUT_LENGTH:
- raise ValueError("max_length greater than MAX_OUTPUT_LENGTH")
-
- self.output_dims = (MAX_OUTPUT_LENGTH, 1)
- self.data_train = None
- self.data_val = None
- self.data_test = None
-
- @property
- def data_filename(self) -> Path:
- """Return name of dataset."""
- return (
- DATA_DIRNAME / (f"ml_{self.max_length}_"
- f"o{self.min_overlap:f}_{self.max_overlap:f}_"
- f"ntr{self.num_train}_"
- f"ntv{self.num_val}_"
- f"nte{self.num_test}.h5")
- )
-
- def prepare_data(self) -> None:
- if self.data_filename.exists():
- return
- np.random.seed(SEED)
- self._generate_data("train")
- self._generate_data("val")
- self._generate_data("test")
-
- def setup(self, stage: str = None) -> None:
- logger.info("EMNISTLinesDataset loading data from HDF5...")
- if stage == "fit" or stage is None:
- print(self.data_filename)
- with h5py.File(self.data_filename, "r") as f:
- x_train = f["x_train"][:]
- y_train = torch.LongTensor(f["y_train"][:])
- x_val = f["x_val"][:]
- y_val = torch.LongTensor(f["y_val"][:])
-
- self.data_train = BaseDataset(
- x_train, y_train, transform=_get_transform(augment=self.augment)
- )
- self.data_val = BaseDataset(
- x_val, y_val, transform=_get_transform(augment=self.augment)
- )
-
- if stage == "test" or stage is None:
- with h5py.File(self.data_filename, "r") as f:
- x_test = f["x_test"][:]
- y_test = torch.LongTensor(f["y_test"][:])
-
- self.data_test = BaseDataset(
- x_test, y_test, transform=_get_transform(augment=False)
- )
-
- def __repr__(self) -> str:
- """Return str about dataset."""
- basic = (
- "EMNISTLines2 Dataset\n" # pylint: disable=no-member
- f"Min overlap: {self.min_overlap}\n"
- f"Max overlap: {self.max_overlap}\n"
- f"Num classes: {len(self.mapping)}\n"
- f"Dims: {self.dims}\n"
- f"Output dims: {self.output_dims}\n"
- )
-
- if not any([self.data_train, self.data_val, self.data_test]):
- return basic
-
- x, y = next(iter(self.train_dataloader()))
- data = (
- f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n"
- f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
- f"Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n"
- )
- return basic + data
-
- def _generate_data(self, split: str) -> None:
- logger.info(f"EMNISTLines generating data for {split}...")
- sentence_generator = SentenceGenerator(
- self.max_length - 2
- ) # Subtract by 2 because start/end token
-
- emnist = self.emnist
- emnist.prepare_data()
- emnist.setup()
-
- if split == "train":
- samples_by_char = _get_samples_by_char(
- emnist.x_train, emnist.y_train, emnist.mapping
- )
- num = self.num_train
- elif split == "val":
- samples_by_char = _get_samples_by_char(
- emnist.x_train, emnist.y_train, emnist.mapping
- )
- num = self.num_val
- else:
- samples_by_char = _get_samples_by_char(
- emnist.x_test, emnist.y_test, emnist.mapping
- )
- num = self.num_test
-
- DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
- with h5py.File(self.data_filename, "a") as f:
- x, y = _create_dataset_of_images(
- num,
- samples_by_char,
- sentence_generator,
- self.min_overlap,
- self.max_overlap,
- self.dims,
- )
- y = convert_strings_to_labels(
- y, emnist.inverse_mapping, length=MAX_OUTPUT_LENGTH
- )
- f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf")
- f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf")
-
-
-def _get_samples_by_char(
- samples: np.ndarray, labels: np.ndarray, mapping: Dict
-) -> defaultdict:
- samples_by_char = defaultdict(list)
- for sample, label in zip(samples, labels):
- samples_by_char[mapping[label]].append(sample)
- return samples_by_char
-
-
-def _select_letter_samples_for_string(string: str, samples_by_char: defaultdict):
- null_image = torch.zeros((28, 28), dtype=torch.uint8)
- sample_image_by_char = {}
- for char in string:
- if char in sample_image_by_char:
- continue
- samples = samples_by_char[char]
- sample = samples[np.random.choice(len(samples))] if samples else null_image
- sample_image_by_char[char] = sample.reshape(28, 28)
- return [sample_image_by_char[char] for char in string]
-
-
-def _construct_image_from_string(
- string: str,
- samples_by_char: defaultdict,
- min_overlap: float,
- max_overlap: float,
- width: int,
-) -> torch.Tensor:
- overlap = np.random.uniform(min_overlap, max_overlap)
- sampled_images = _select_letter_samples_for_string(string, samples_by_char)
- N = len(sampled_images)
- H, W = sampled_images[0].shape
- next_overlap_width = W - int(overlap * W)
- concatenated_image = torch.zeros((H, width), dtype=torch.uint8)
- x = IMAGE_X_PADDING
- for image in sampled_images:
- concatenated_image[:, x : (x + W)] += image
- x += next_overlap_width
- return torch.minimum(torch.Tensor([255]), concatenated_image)
-
-
-def _create_dataset_of_images(
- num_samples: int,
- samples_by_char: defaultdict,
- sentence_generator: SentenceGenerator,
- min_overlap: float,
- max_overlap: float,
- dims: Tuple,
-) -> Tuple[torch.Tensor, torch.Tensor]:
- images = torch.zeros((num_samples, IMAGE_HEIGHT, dims[2]))
- labels = []
- for n in range(num_samples):
- label = sentence_generator.generate()
- crop = _construct_image_from_string(
- label, samples_by_char, min_overlap, max_overlap, dims[-1]
- )
- height = crop.shape[0]
- y = (IMAGE_HEIGHT - height) // 2
- images[n, y : (y + height), :] = crop
- labels.append(label)
- return images, labels
-
-
-def _get_transform(augment: bool = False) -> Callable:
- if not augment:
- return transforms.Compose([transforms.ToTensor()])
- return transforms.Compose(
- [
- transforms.ToTensor(),
- transforms.ColorJitter(brightness=(0.5, 1.0)),
- transforms.RandomAffine(
- degrees=3,
- translate=(0.0, 0.05),
- scale=(0.4, 1.1),
- shear=(-40, 50),
- interpolation=InterpolationMode.BILINEAR,
- fill=0,
- ),
- ]
- )
-
-
-def generate_emnist_lines() -> None:
- """Generates a synthetic handwritten dataset and displays info,"""
- load_and_print_info(EMNISTLines)
diff --git a/text_recognizer/datasets/iam_dataset.py b/text_recognizer/datasets/iam_dataset.py
deleted file mode 100644
index a8998b9..0000000
--- a/text_recognizer/datasets/iam_dataset.py
+++ /dev/null
@@ -1,133 +0,0 @@
-"""Class for loading the IAM dataset, which encompasses both paragraphs and lines, with associated utilities."""
-import os
-from typing import Any, Dict, List
-import zipfile
-
-from boltons.cacheutils import cachedproperty
-import defusedxml.ElementTree as ET
-from loguru import logger
-import toml
-
-from text_recognizer.datasets.util import _download_raw_dataset, DATA_DIRNAME
-
-RAW_DATA_DIRNAME = DATA_DIRNAME / "raw" / "iam"
-METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
-EXTRACTED_DATASET_DIRNAME = RAW_DATA_DIRNAME / "iamdb"
-RAW_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
-
-DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be.
-LINE_REGION_PADDING = 0 # Add this many pixels around the exact coordinates.
-
-
-class IamDataset:
- """IAM dataset.
-
- "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text,
- which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels."
- From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database
-
- The data split we will use is
- IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text lines.
- The validation set has been merged into the train set.
- The train set has 7,101 lines from 326 writers.
- The test set has 1,861 lines from 128 writers.
- The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only.
-
- """
-
- def __init__(self) -> None:
- self.metadata = toml.load(METADATA_FILENAME)
-
- def load_or_generate_data(self) -> None:
- """Downloads IAM dataset if xml files does not exist."""
- if not self.xml_filenames:
- self._download_iam()
-
- @property
- def xml_filenames(self) -> List:
- """List of xml filenames."""
- return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml"))
-
- @property
- def form_filenames(self) -> List:
- """List of forms filenames."""
- return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg"))
-
- def _download_iam(self) -> None:
- curdir = os.getcwd()
- os.chdir(RAW_DATA_DIRNAME)
- _download_raw_dataset(self.metadata)
- _extract_raw_dataset(self.metadata)
- os.chdir(curdir)
-
- @property
- def form_filenames_by_id(self) -> Dict:
- """Creates a dictionary with filenames as keys and forms as values."""
- return {filename.stem: filename for filename in self.form_filenames}
-
- @cachedproperty
- def line_strings_by_id(self) -> Dict:
- """Return a dict from name of IAM form to a list of line texts in it."""
- return {
- filename.stem: _get_line_strings_from_xml_file(filename)
- for filename in self.xml_filenames
- }
-
- @cachedproperty
- def line_regions_by_id(self) -> Dict:
- """Return a dict from name of IAM form to a list of (x1, x2, y1, y2) coordinates of all lines in it."""
- return {
- filename.stem: _get_line_regions_from_xml_file(filename)
- for filename in self.xml_filenames
- }
-
- def __repr__(self) -> str:
- """Print info about dataset."""
- return "IAM Dataset\n" f"Number of forms: {len(self.xml_filenames)}\n"
-
-
-def _extract_raw_dataset(metadata: Dict) -> None:
- logger.info("Extracting IAM data.")
- with zipfile.ZipFile(metadata["filename"], "r") as zip_file:
- zip_file.extractall()
-
-
-def _get_line_strings_from_xml_file(filename: str) -> List[str]:
- """Get the text content of each line. Note that we replace &quot; with "."""
- xml_root_element = ET.parse(filename).getroot() # nosec
- xml_line_elements = xml_root_element.findall("handwritten-part/line")
- return [el.attrib["text"].replace("&quot;", '"') for el in xml_line_elements]
-
-
-def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]:
- """Get the line region dict for each line."""
- xml_root_element = ET.parse(filename).getroot() # nosec
- xml_line_elements = xml_root_element.findall("handwritten-part/line")
- return [_get_line_region_from_xml_element(el) for el in xml_line_elements]
-
-
-def _get_line_region_from_xml_element(xml_line: Any) -> Dict[str, int]:
- """Extracts coordinates for each line of text."""
- # TODO: fix input!
- word_elements = xml_line.findall("word/cmp")
- x1s = [int(el.attrib["x"]) for el in word_elements]
- y1s = [int(el.attrib["y"]) for el in word_elements]
- x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements]
- y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements]
- return {
- "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
- "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
- "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
- "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
- }
-
-
-def main() -> None:
- """Initializes the dataset and print info about the dataset."""
- dataset = IamDataset()
- dataset.load_or_generate_data()
- print(dataset)
-
-
-if __name__ == "__main__":
- main()
diff --git a/text_recognizer/datasets/iam_lines_dataset.py b/text_recognizer/datasets/iam_lines_dataset.py
deleted file mode 100644
index 1cb84bd..0000000
--- a/text_recognizer/datasets/iam_lines_dataset.py
+++ /dev/null
@@ -1,110 +0,0 @@
-"""IamLinesDataset class."""
-from typing import Callable, Dict, List, Optional, Tuple, Union
-
-import h5py
-from loguru import logger
-import torch
-from torch import Tensor
-from torchvision.transforms import ToTensor
-
-from text_recognizer.datasets.dataset import Dataset
-from text_recognizer.datasets.util import (
- compute_sha256,
- DATA_DIRNAME,
- download_url,
- EmnistMapper,
-)
-
-
-PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines"
-PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "iam_lines.h5"
-PROCESSED_DATA_URL = (
- "https://s3-us-west-2.amazonaws.com/fsdl-public-assets/iam_lines.h5"
-)
-
-
-class IamLinesDataset(Dataset):
- """IAM lines datasets for handwritten text lines."""
-
- def __init__(
- self,
- train: bool = False,
- subsample_fraction: float = None,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
- init_token: Optional[str] = None,
- pad_token: Optional[str] = None,
- eos_token: Optional[str] = None,
- lower: bool = False,
- ) -> None:
- self.pad_token = "_" if pad_token is None else pad_token
-
- super().__init__(
- train=train,
- subsample_fraction=subsample_fraction,
- transform=transform,
- target_transform=target_transform,
- init_token=init_token,
- pad_token=pad_token,
- eos_token=eos_token,
- lower=lower,
- )
-
- @property
- def input_shape(self) -> Tuple:
- """Input shape of the data."""
- return self.data.shape[1:] if self.data is not None else None
-
- @property
- def output_shape(self) -> Tuple:
- """Output shape of the data."""
- return (
- self.targets.shape[1:] + (self.num_classes,)
- if self.targets is not None
- else None
- )
-
- def load_or_generate_data(self) -> None:
- """Load or generate dataset data."""
- if not PROCESSED_DATA_FILENAME.exists():
- PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
- logger.info("Downloading IAM lines...")
- download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME)
- with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
- self._data = f[f"x_{self.split}"][:]
- self._targets = f[f"y_{self.split}"][:]
- self._subsample()
-
- def __repr__(self) -> str:
- """Print info about the dataset."""
- return (
- "IAM Lines Dataset\n" # pylint: disable=no-member
- f"Number classes: {self.num_classes}\n"
- f"Mapping: {self.mapper.mapping}\n"
- f"Data: {self.data.shape}\n"
- f"Targets: {self.targets.shape}\n"
- )
-
- def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]:
- """Fetches data, target pair of the dataset for a given and index or indices.
-
- Args:
- index (Union[int, Tensor]): Either a list or int of indices/index.
-
- Returns:
- Tuple[Tensor, Tensor]: Data target pair.
-
- """
- if torch.is_tensor(index):
- index = index.tolist()
-
- data = self.data[index]
- targets = self.targets[index]
-
- if self.transform:
- data = self.transform(data)
-
- if self.target_transform:
- targets = self.target_transform(targets)
-
- return data, targets
diff --git a/text_recognizer/datasets/iam_paragraphs_dataset.py b/text_recognizer/datasets/iam_paragraphs_dataset.py
deleted file mode 100644
index 8ba5142..0000000
--- a/text_recognizer/datasets/iam_paragraphs_dataset.py
+++ /dev/null
@@ -1,291 +0,0 @@
-"""IamParagraphsDataset class and functions for data processing."""
-import random
-from typing import Callable, Dict, List, Optional, Tuple, Union
-
-import click
-import cv2
-import h5py
-from loguru import logger
-import numpy as np
-import torch
-from torch import Tensor
-from torchvision.transforms import ToTensor
-
-from text_recognizer import util
-from text_recognizer.datasets.dataset import Dataset
-from text_recognizer.datasets.iam_dataset import IamDataset
-from text_recognizer.datasets.util import (
- compute_sha256,
- DATA_DIRNAME,
- download_url,
- EmnistMapper,
-)
-
-INTERIM_DATA_DIRNAME = DATA_DIRNAME / "interim" / "iam_paragraphs"
-DEBUG_CROPS_DIRNAME = INTERIM_DATA_DIRNAME / "debug_crops"
-PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_paragraphs"
-CROPS_DIRNAME = PROCESSED_DATA_DIRNAME / "crops"
-GT_DIRNAME = PROCESSED_DATA_DIRNAME / "gt"
-
-PARAGRAPH_BUFFER = 50 # Pixels in the IAM form images to leave around the lines.
-TEST_FRACTION = 0.2
-SEED = 4711
-
-
-class IamParagraphsDataset(Dataset):
- """IAM Paragraphs dataset for paragraphs of handwritten text."""
-
- def __init__(
- self,
- train: bool = False,
- subsample_fraction: float = None,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
- ) -> None:
- super().__init__(
- train=train,
- subsample_fraction=subsample_fraction,
- transform=transform,
- target_transform=target_transform,
- )
- # Load Iam dataset.
- self.iam_dataset = IamDataset()
-
- self.num_classes = 3
- self._input_shape = (256, 256)
- self._output_shape = self._input_shape + (self.num_classes,)
- self._ids = None
-
- def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]:
- """Fetches data, target pair of the dataset for a given and index or indices.
-
- Args:
- index (Union[int, Tensor]): Either a list or int of indices/index.
-
- Returns:
- Tuple[Tensor, Tensor]: Data target pair.
-
- """
- if torch.is_tensor(index):
- index = index.tolist()
-
- data = self.data[index]
- targets = self.targets[index]
-
- seed = np.random.randint(SEED)
- random.seed(seed) # apply this seed to target tranfsorms
- torch.manual_seed(seed) # needed for torchvision 0.7
- if self.transform:
- data = self.transform(data)
-
- random.seed(seed) # apply this seed to target tranfsorms
- torch.manual_seed(seed) # needed for torchvision 0.7
- if self.target_transform:
- targets = self.target_transform(targets)
-
- return data, targets.long()
-
- @property
- def ids(self) -> Tensor:
- """Ids of the dataset."""
- return self._ids
-
- def get_data_and_target_from_id(self, id_: str) -> Tuple[Tensor, Tensor]:
- """Get data target pair from id."""
- ind = self.ids.index(id_)
- return self.data[ind], self.targets[ind]
-
- def load_or_generate_data(self) -> None:
- """Load or generate dataset data."""
- num_actual = len(list(CROPS_DIRNAME.glob("*.jpg")))
- num_targets = len(self.iam_dataset.line_regions_by_id)
-
- if num_actual < num_targets - 2:
- self._process_iam_paragraphs()
-
- self._data, self._targets, self._ids = _load_iam_paragraphs()
- self._get_random_split()
- self._subsample()
-
- def _get_random_split(self) -> None:
- np.random.seed(SEED)
- num_train = int((1 - TEST_FRACTION) * self.data.shape[0])
- indices = np.random.permutation(self.data.shape[0])
- train_indices, test_indices = indices[:num_train], indices[num_train:]
- if self.train:
- self._data = self.data[train_indices]
- self._targets = self.targets[train_indices]
- else:
- self._data = self.data[test_indices]
- self._targets = self.targets[test_indices]
-
- def _process_iam_paragraphs(self) -> None:
- """Crop the part with the text.
-
- For each page, crop out the part of it that correspond to the paragraph of text, and make sure all crops are
- self.input_shape. The ground truth data is the same size, with a one-hot vector at each pixel
- corresponding to labels 0=background, 1=odd-numbered line, 2=even-numbered line
- """
- crop_dims = self._decide_on_crop_dims()
- CROPS_DIRNAME.mkdir(parents=True, exist_ok=True)
- DEBUG_CROPS_DIRNAME.mkdir(parents=True, exist_ok=True)
- GT_DIRNAME.mkdir(parents=True, exist_ok=True)
- logger.info(
- f"Cropping paragraphs, generating ground truth, and saving debugging images to {DEBUG_CROPS_DIRNAME}"
- )
- for filename in self.iam_dataset.form_filenames:
- id_ = filename.stem
- line_region = self.iam_dataset.line_regions_by_id[id_]
- _crop_paragraph_image(filename, line_region, crop_dims, self.input_shape)
-
- def _decide_on_crop_dims(self) -> Tuple[int, int]:
- """Decide on the dimensions to crop out of the form image.
-
- Since image width is larger than a comfortable crop around the longest paragraph,
- we will make the crop a square form factor.
- And since the found dimensions 610x610 are pretty close to 512x512,
- we might as well resize crops and make it exactly that, which lets us
- do all kinds of power-of-2 pooling and upsampling should we choose to.
-
- Returns:
- Tuple[int, int]: A tuple of crop dimensions.
-
- Raises:
- RuntimeError: When max crop height is larger than max crop width.
-
- """
-
- sample_form_filename = self.iam_dataset.form_filenames[0]
- sample_image = util.read_image(sample_form_filename, grayscale=True)
- max_crop_width = sample_image.shape[1]
- max_crop_height = _get_max_paragraph_crop_height(
- self.iam_dataset.line_regions_by_id
- )
- if not max_crop_height <= max_crop_width:
- raise RuntimeError(
- f"Max crop height is larger then max crop width: {max_crop_height} >= {max_crop_width}"
- )
-
- crop_dims = (max_crop_width, max_crop_width)
- logger.info(
- f"Max crop width and height were found to be {max_crop_width}x{max_crop_height}."
- )
- logger.info(f"Setting them to {max_crop_width}x{max_crop_width}")
- return crop_dims
-
- def __repr__(self) -> str:
- """Return info about the dataset."""
- return (
- "IAM Paragraph Dataset\n" # pylint: disable=no-member
- f"Num classes: {self.num_classes}\n"
- f"Data: {self.data.shape}\n"
- f"Targets: {self.targets.shape}\n"
- )
-
-
-def _get_max_paragraph_crop_height(line_regions_by_id: Dict) -> int:
- heights = []
- for regions in line_regions_by_id.values():
- min_y1 = min(region["y1"] for region in regions) - PARAGRAPH_BUFFER
- max_y2 = max(region["y2"] for region in regions) + PARAGRAPH_BUFFER
- height = max_y2 - min_y1
- heights.append(height)
- return max(heights)
-
-
-def _crop_paragraph_image(
- filename: str, line_regions: Dict, crop_dims: Tuple[int, int], final_dims: Tuple
-) -> None:
- image = util.read_image(filename, grayscale=True)
-
- min_y1 = min(region["y1"] for region in line_regions) - PARAGRAPH_BUFFER
- max_y2 = max(region["y2"] for region in line_regions) + PARAGRAPH_BUFFER
- height = max_y2 - min_y1
- crop_height = crop_dims[0]
- buffer = (crop_height - height) // 2
-
- # Generate image crop.
- image_crop = 255 * np.ones(crop_dims, dtype=np.uint8)
- try:
- image_crop[buffer : buffer + height] = image[min_y1:max_y2]
- except Exception as e: # pylint: disable=broad-except
- logger.error(f"Rescued {filename}: {e}")
- return
-
- # Generate ground truth.
- gt_image = np.zeros_like(image_crop, dtype=np.uint8)
- for index, region in enumerate(line_regions):
- gt_image[
- (region["y1"] - min_y1 + buffer) : (region["y2"] - min_y1 + buffer),
- region["x1"] : region["x2"],
- ] = (index % 2 + 1)
-
- # Generate image for debugging.
- import matplotlib.pyplot as plt
-
- cmap = plt.get_cmap("Set1")
- image_crop_for_debug = np.dstack([image_crop, image_crop, image_crop])
- for index, region in enumerate(line_regions):
- color = [255 * _ for _ in cmap(index)[:-1]]
- cv2.rectangle(
- image_crop_for_debug,
- (region["x1"], region["y1"] - min_y1 + buffer),
- (region["x2"], region["y2"] - min_y1 + buffer),
- color,
- 3,
- )
- image_crop_for_debug = cv2.resize(
- image_crop_for_debug, final_dims, interpolation=cv2.INTER_AREA
- )
- util.write_image(image_crop_for_debug, DEBUG_CROPS_DIRNAME / f"{filename.stem}.jpg")
-
- image_crop = cv2.resize(image_crop, final_dims, interpolation=cv2.INTER_AREA)
- util.write_image(image_crop, CROPS_DIRNAME / f"{filename.stem}.jpg")
-
- gt_image = cv2.resize(gt_image, final_dims, interpolation=cv2.INTER_NEAREST)
- util.write_image(gt_image, GT_DIRNAME / f"{filename.stem}.png")
-
-
-def _load_iam_paragraphs() -> None:
- logger.info("Loading IAM paragraph crops and ground truth from image files...")
- images = []
- gt_images = []
- ids = []
- for filename in CROPS_DIRNAME.glob("*.jpg"):
- id_ = filename.stem
- image = util.read_image(filename, grayscale=True)
- image = 1.0 - image / 255
-
- gt_filename = GT_DIRNAME / f"{id_}.png"
- gt_image = util.read_image(gt_filename, grayscale=True)
-
- images.append(image)
- gt_images.append(gt_image)
- ids.append(id_)
- images = np.array(images).astype(np.float32)
- gt_images = np.array(gt_images).astype(np.uint8)
- ids = np.array(ids)
- return images, gt_images, ids
-
-
-@click.command()
-@click.option(
- "--subsample_fraction",
- type=float,
- default=None,
- help="The subsampling factor of the dataset.",
-)
-def main(subsample_fraction: float) -> None:
- """Load dataset and print info."""
- logger.info("Creating train set...")
- dataset = IamParagraphsDataset(train=True, subsample_fraction=subsample_fraction)
- dataset.load_or_generate_data()
- print(dataset)
- logger.info("Creating test set...")
- dataset = IamParagraphsDataset(subsample_fraction=subsample_fraction)
- dataset.load_or_generate_data()
- print(dataset)
-
-
-if __name__ == "__main__":
- main()
diff --git a/text_recognizer/datasets/iam_preprocessor.py b/text_recognizer/datasets/iam_preprocessor.py
deleted file mode 100644
index a93eb00..0000000
--- a/text_recognizer/datasets/iam_preprocessor.py
+++ /dev/null
@@ -1,196 +0,0 @@
-"""Preprocessor for extracting word letters from the IAM dataset.
-
-The code is mostly stolen from:
-
- https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py
-
-"""
-
-import collections
-import itertools
-from pathlib import Path
-import re
-from typing import List, Optional, Union
-
-import click
-from loguru import logger
-import torch
-
-
-def load_metadata(
- data_dir: Path, wordsep: str, use_words: bool = False
-) -> collections.defaultdict:
- """Loads IAM metadata and returns it as a dictionary."""
- forms = collections.defaultdict(list)
- filename = "words.txt" if use_words else "lines.txt"
-
- with open(data_dir / "ascii" / filename, "r") as f:
- lines = (line.strip().split() for line in f if line[0] != "#")
- for line in lines:
- # Skip word segmentation errors.
- if use_words and line[1] == "err":
- continue
- text = " ".join(line[8:])
-
- # Remove garbage tokens:
- text = text.replace("#", "")
-
- # Swap word sep form | to wordsep
- text = re.sub(r"\|+|\s", wordsep, text).strip(wordsep)
- form_key = "-".join(line[0].split("-")[:2])
- line_key = "-".join(line[0].split("-")[:3])
- box_idx = 4 - use_words
- box = tuple(int(val) for val in line[box_idx : box_idx + 4])
- forms[form_key].append({"key": line_key, "box": box, "text": text})
- return forms
-
-
-class Preprocessor:
- """A preprocessor for the IAM dataset."""
-
- # TODO: add lower case only to when generating...
-
- def __init__(
- self,
- data_dir: Union[str, Path],
- num_features: int,
- tokens_path: Optional[Union[str, Path]] = None,
- lexicon_path: Optional[Union[str, Path]] = None,
- use_words: bool = False,
- prepend_wordsep: bool = False,
- ) -> None:
- self.wordsep = "▁"
- self._use_word = use_words
- self._prepend_wordsep = prepend_wordsep
-
- self.data_dir = Path(data_dir)
-
- self.forms = load_metadata(self.data_dir, self.wordsep, use_words=use_words)
-
- # Load the set of graphemes:
- graphemes = set()
- for _, form in self.forms.items():
- for line in form:
- graphemes.update(line["text"].lower())
- self.graphemes = sorted(graphemes)
-
- # Build the token-to-index and index-to-token maps.
- if tokens_path is not None:
- with open(tokens_path, "r") as f:
- self.tokens = [line.strip() for line in f]
- else:
- self.tokens = self.graphemes
-
- if lexicon_path is not None:
- with open(lexicon_path, "r") as f:
- lexicon = (line.strip().split() for line in f)
- lexicon = {line[0]: line[1:] for line in lexicon}
- self.lexicon = lexicon
- else:
- self.lexicon = None
-
- self.graphemes_to_index = {t: i for i, t in enumerate(self.graphemes)}
- self.tokens_to_index = {t: i for i, t in enumerate(self.tokens)}
- self.num_features = num_features
- self.text = []
-
- @property
- def num_tokens(self) -> int:
- """Returns the number or tokens."""
- return len(self.tokens)
-
- @property
- def use_words(self) -> bool:
- """If words are used."""
- return self._use_word
-
- def extract_train_text(self) -> None:
- """Extracts training text."""
- keys = []
- with open(self.data_dir / "task" / "trainset.txt") as f:
- keys.extend((line.strip() for line in f))
-
- for _, examples in self.forms.items():
- for example in examples:
- if example["key"] not in keys:
- continue
- self.text.append(example["text"].lower())
-
- def to_index(self, line: str) -> torch.LongTensor:
- """Converts text to a tensor of indices."""
- token_to_index = self.graphemes_to_index
- if self.lexicon is not None:
- if len(line) > 0:
- # If the word is not found in the lexicon, fall back to letters.
- line = [
- t
- for w in line.split(self.wordsep)
- for t in self.lexicon.get(w, self.wordsep + w)
- ]
- token_to_index = self.tokens_to_index
- if self._prepend_wordsep:
- line = itertools.chain([self.wordsep], line)
- return torch.LongTensor([token_to_index[t] for t in line])
-
- def to_text(self, indices: List[int]) -> str:
- """Converts indices to text."""
- # Roughly the inverse of `to_index`
- encoding = self.graphemes
- if self.lexicon is not None:
- encoding = self.tokens
- return self._post_process(encoding[i] for i in indices)
-
- def tokens_to_text(self, indices: List[int]) -> str:
- """Converts tokens to text."""
- return self._post_process(self.tokens[i] for i in indices)
-
- def _post_process(self, indices: List[int]) -> str:
- """A list join."""
- return "".join(indices).strip(self.wordsep)
-
-
-@click.command()
-@click.option("--data_dir", type=str, default=None, help="Path to iam dataset")
-@click.option(
- "--use_words", is_flag=True, help="Load word segmented dataset instead of lines"
-)
-@click.option(
- "--save_text", type=str, default=None, help="Path to save parsed train text"
-)
-@click.option("--save_tokens", type=str, default=None, help="Path to save tokens")
-def cli(
- data_dir: Optional[str],
- use_words: bool,
- save_text: Optional[str],
- save_tokens: Optional[str],
-) -> None:
- """CLI for extracting text data from the iam dataset."""
- if data_dir is None:
- data_dir = (
- Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb"
- )
- logger.debug(f"Using data dir: {data_dir}")
- if not data_dir.exists():
- raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
- else:
- data_dir = Path(data_dir)
-
- preprocessor = Preprocessor(data_dir, 64, use_words=use_words)
- preprocessor.extract_train_text()
-
- processed_dir = data_dir.parents[2] / "processed" / "iam_lines"
- logger.debug(f"Saving processed files at: {processed_dir}")
-
- if save_text is not None:
- logger.info("Saving training text")
- with open(processed_dir / save_text, "w") as f:
- f.write("\n".join(t for t in preprocessor.text))
-
- if save_tokens is not None:
- logger.info("Saving tokens")
- with open(processed_dir / save_tokens, "w") as f:
- f.write("\n".join(preprocessor.tokens))
-
-
-if __name__ == "__main__":
- cli()
diff --git a/text_recognizer/datasets/sentence_generator.py b/text_recognizer/datasets/sentence_generator.py
deleted file mode 100644
index 53b781c..0000000
--- a/text_recognizer/datasets/sentence_generator.py
+++ /dev/null
@@ -1,85 +0,0 @@
-"""Downloading the Brown corpus with NLTK for sentence generating."""
-
-import itertools
-import re
-import string
-from typing import Optional
-
-import nltk
-from nltk.corpus.reader.util import ConcatenatedCorpusView
-import numpy as np
-
-from text_recognizer.datasets.util import DATA_DIRNAME
-
-NLTK_DATA_DIRNAME = DATA_DIRNAME / "downloaded" / "nltk"
-
-
-class SentenceGenerator:
- """Generates text sentences using the Brown corpus."""
-
- def __init__(self, max_length: Optional[int] = None) -> None:
- """Loads the corpus and sets word start indices."""
- self.corpus = brown_corpus()
- self.word_start_indices = [0] + [
- _.start(0) + 1 for _ in re.finditer(" ", self.corpus)
- ]
- self.max_length = max_length
-
- def generate(self, max_length: Optional[int] = None) -> str:
- """Generates a word or sentences from the Brown corpus.
-
- Sample a string from the Brown corpus of length at least one word and at most max_length, padding to
- max_length with the '_' characters if sentence is shorter.
-
- Args:
- max_length (Optional[int]): The maximum number of characters in the sentence. Defaults to None.
-
- Returns:
- str: A sentence from the Brown corpus.
-
- Raises:
- ValueError: If max_length was not specified at initialization and not given as an argument.
-
- """
- if max_length is None:
- max_length = self.max_length
- if max_length is None:
- raise ValueError(
- "Must provide max_length to this method or when making this object."
- )
-
- for _ in range(10):
- try:
- index = np.random.randint(0, len(self.word_start_indices) - 1)
- start_index = self.word_start_indices[index]
- end_index_candidates = []
- for index in range(index + 1, len(self.word_start_indices)):
- if self.word_start_indices[index] - start_index > max_length:
- break
- end_index_candidates.append(self.word_start_indices[index])
- end_index = np.random.choice(end_index_candidates)
- sampled_text = self.corpus[start_index:end_index].strip()
- return sampled_text
- except Exception:
- pass
- raise RuntimeError("Was not able to generate a valid string")
-
-
-def brown_corpus() -> str:
- """Returns a single string with the Brown corpus with all punctuations stripped."""
- sentences = load_nltk_brown_corpus()
- corpus = " ".join(itertools.chain.from_iterable(sentences))
- corpus = corpus.translate({ord(c): None for c in string.punctuation})
- corpus = re.sub(" +", " ", corpus)
- return corpus
-
-
-def load_nltk_brown_corpus() -> ConcatenatedCorpusView:
- """Load the Brown corpus using the NLTK library."""
- nltk.data.path.append(NLTK_DATA_DIRNAME)
- try:
- nltk.corpus.brown.sents()
- except LookupError:
- NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
- nltk.download("brown", download_dir=NLTK_DATA_DIRNAME)
- return nltk.corpus.brown.sents()
diff --git a/text_recognizer/datasets/transforms.py b/text_recognizer/datasets/transforms.py
deleted file mode 100644
index b6a48f5..0000000
--- a/text_recognizer/datasets/transforms.py
+++ /dev/null
@@ -1,266 +0,0 @@
-"""Transforms for PyTorch datasets."""
-from abc import abstractmethod
-from pathlib import Path
-import random
-from typing import Any, Optional, Union
-
-from loguru import logger
-import numpy as np
-from PIL import Image
-import torch
-from torch import Tensor
-import torch.nn.functional as F
-from torchvision import transforms
-from torchvision.transforms import (
- ColorJitter,
- Compose,
- Normalize,
- RandomAffine,
- RandomHorizontalFlip,
- RandomRotation,
- ToPILImage,
- ToTensor,
-)
-
-from text_recognizer.datasets.iam_preprocessor import Preprocessor
-from text_recognizer.datasets.util import EmnistMapper
-
-
-class RandomResizeCrop:
- """Image transform with random resize and crop applied.
-
- Stolen from
-
- https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py
-
- """
-
- def __init__(self, jitter: int = 10, ratio: float = 0.5) -> None:
- self.jitter = jitter
- self.ratio = ratio
-
- def __call__(self, img: np.ndarray) -> np.ndarray:
- """Applies random crop and rotation to an image."""
- w, h = img.size
-
- # pad with white:
- img = transforms.functional.pad(img, self.jitter, fill=255)
-
- # crop at random (x, y):
- x = self.jitter + random.randint(-self.jitter, self.jitter)
- y = self.jitter + random.randint(-self.jitter, self.jitter)
-
- # randomize aspect ratio:
- size_w = w * random.uniform(1 - self.ratio, 1 + self.ratio)
- size = (h, int(size_w))
- img = transforms.functional.resized_crop(img, y, x, h, w, size)
- return img
-
-
-class Transpose:
- """Transposes the EMNIST image to the correct orientation."""
-
- def __call__(self, image: Image) -> np.ndarray:
- """Swaps axis."""
- return np.array(image).swapaxes(0, 1)
-
-
-class Resize:
- """Resizes a tensor to a specified width."""
-
- def __init__(self, width: int = 952) -> None:
- # The default is 952 because of the IAM dataset.
- self.width = width
-
- def __call__(self, image: Tensor) -> Tensor:
- """Resize tensor in the last dimension."""
- return F.interpolate(image, size=self.width, mode="nearest")
-
-
-class AddTokens:
- """Adds start of sequence and end of sequence tokens to target tensor."""
-
- def __init__(self, pad_token: str, eos_token: str, init_token: str = None) -> None:
- self.init_token = init_token
- self.pad_token = pad_token
- self.eos_token = eos_token
- if self.init_token is not None:
- self.emnist_mapper = EmnistMapper(
- init_token=self.init_token,
- pad_token=self.pad_token,
- eos_token=self.eos_token,
- )
- else:
- self.emnist_mapper = EmnistMapper(
- pad_token=self.pad_token, eos_token=self.eos_token,
- )
- self.pad_value = self.emnist_mapper(self.pad_token)
- self.eos_value = self.emnist_mapper(self.eos_token)
-
- def __call__(self, target: Tensor) -> Tensor:
- """Adds a sos token to the begining and a eos token to the end of a target sequence."""
- dtype, device = target.dtype, target.device
-
- # Find the where padding starts.
- pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item()
-
- target[pad_index] = self.eos_value
-
- if self.init_token is not None:
- self.sos_value = self.emnist_mapper(self.init_token)
- sos = torch.tensor([self.sos_value], dtype=dtype, device=device)
- target = torch.cat([sos, target], dim=0)
-
- return target
-
-
-class ApplyContrast:
- """Sets everything below a threshold to zero, i.e. increase contrast."""
-
- def __init__(self, low: float = 0.0, high: float = 0.25) -> None:
- self.low = low
- self.high = high
-
- def __call__(self, x: Tensor) -> Tensor:
- """Apply mask binary mask to input tensor."""
- mask = x > np.random.RandomState().uniform(low=self.low, high=self.high)
- return x * mask
-
-
-class Unsqueeze:
- """Add a dimension to the tensor."""
-
- def __call__(self, x: Tensor) -> Tensor:
- """Adds dim."""
- return x.unsqueeze(0)
-
-
-class Squeeze:
- """Removes the first dimension of a tensor."""
-
- def __call__(self, x: Tensor) -> Tensor:
- """Removes first dim."""
- return x.squeeze(0)
-
-
-class ToLower:
- """Converts target to lower case."""
-
- def __call__(self, target: Tensor) -> Tensor:
- """Corrects index value in target tensor."""
- device = target.device
- return torch.stack([x - 26 if x > 35 else x for x in target]).to(device)
-
-
-class ToCharcters:
- """Converts integers to characters."""
-
- def __init__(
- self, pad_token: str, eos_token: str, init_token: str = None, lower: bool = True
- ) -> None:
- self.init_token = init_token
- self.pad_token = pad_token
- self.eos_token = eos_token
- if self.init_token is not None:
- self.emnist_mapper = EmnistMapper(
- init_token=self.init_token,
- pad_token=self.pad_token,
- eos_token=self.eos_token,
- lower=lower,
- )
- else:
- self.emnist_mapper = EmnistMapper(
- pad_token=self.pad_token, eos_token=self.eos_token, lower=lower
- )
-
- def __call__(self, y: Tensor) -> str:
- """Converts a Tensor to a str."""
- return (
- "".join([self.emnist_mapper(int(i)) for i in y])
- .strip("_")
- .replace(" ", "▁")
- )
-
-
-class WordPieces:
- """Abstract transform for word pieces."""
-
- def __init__(
- self,
- num_features: int,
- data_dir: Optional[Union[str, Path]] = None,
- tokens: Optional[Union[str, Path]] = None,
- lexicon: Optional[Union[str, Path]] = None,
- use_words: bool = False,
- prepend_wordsep: bool = False,
- ) -> None:
- if data_dir is None:
- data_dir = (
- Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb"
- )
- logger.debug(f"Using data dir: {data_dir}")
- if not data_dir.exists():
- raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
- else:
- data_dir = Path(data_dir)
- processed_path = (
- Path(__file__).resolve().parents[3] / "data" / "processed" / "iam_lines"
- )
- tokens_path = processed_path / tokens
- lexicon_path = processed_path / lexicon
-
- self.preprocessor = Preprocessor(
- data_dir,
- num_features,
- tokens_path,
- lexicon_path,
- use_words,
- prepend_wordsep,
- )
-
- @abstractmethod
- def __call__(self, *args, **kwargs) -> Any:
- """Transforms input."""
- ...
-
-
-class ToWordPieces(WordPieces):
- """Transforms str to word pieces."""
-
- def __init__(
- self,
- num_features: int,
- data_dir: Optional[Union[str, Path]] = None,
- tokens: Optional[Union[str, Path]] = None,
- lexicon: Optional[Union[str, Path]] = None,
- use_words: bool = False,
- prepend_wordsep: bool = False,
- ) -> None:
- super().__init__(
- num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep
- )
-
- def __call__(self, line: str) -> Tensor:
- """Transforms str to word pieces."""
- return self.preprocessor.to_index(line)
-
-
-class ToText(WordPieces):
- """Takes word pieces and converts them to text."""
-
- def __init__(
- self,
- num_features: int,
- data_dir: Optional[Union[str, Path]] = None,
- tokens: Optional[Union[str, Path]] = None,
- lexicon: Optional[Union[str, Path]] = None,
- use_words: bool = False,
- prepend_wordsep: bool = False,
- ) -> None:
- super().__init__(
- num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep
- )
-
- def __call__(self, x: Tensor) -> str:
- """Converts tensor to text."""
- return self.preprocessor.to_text(x.tolist())
diff --git a/text_recognizer/datasets/util.py b/text_recognizer/datasets/util.py
deleted file mode 100644
index da87756..0000000
--- a/text_recognizer/datasets/util.py
+++ /dev/null
@@ -1,209 +0,0 @@
-"""Util functions for datasets."""
-import hashlib
-import json
-import os
-from pathlib import Path
-import string
-from typing import Dict, List, Optional, Union
-from urllib.request import urlretrieve
-
-from loguru import logger
-import numpy as np
-import torch
-from torch import Tensor
-from torchvision.datasets import EMNIST
-from tqdm import tqdm
-
-DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data"
-ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json"
-
-
-def save_emnist_essentials(emnsit_dataset: EMNIST = EMNIST) -> None:
- """Extract and saves EMNIST essentials."""
- labels = emnsit_dataset.classes
- labels.sort()
- mapping = [(i, str(label)) for i, label in enumerate(labels)]
- essentials = {
- "mapping": mapping,
- "input_shape": tuple(np.array(emnsit_dataset[0][0]).shape[:]),
- }
- logger.info("Saving emnist essentials...")
- with open(ESSENTIALS_FILENAME, "w") as f:
- json.dump(essentials, f)
-
-
-def download_emnist() -> None:
- """Download the EMNIST dataset via the PyTorch class."""
- logger.info(f"Data directory is: {DATA_DIRNAME}")
- dataset = EMNIST(root=DATA_DIRNAME, split="byclass", download=True)
- save_emnist_essentials(dataset)
-
-
-class EmnistMapper:
- """Mapper between network output to Emnist character."""
-
- def __init__(
- self,
- pad_token: str,
- init_token: Optional[str] = None,
- eos_token: Optional[str] = None,
- lower: bool = False,
- ) -> None:
- """Loads the emnist essentials file with the mapping and input shape."""
- self.init_token = init_token
- self.pad_token = pad_token
- self.eos_token = eos_token
- self.lower = lower
-
- self.essentials = self._load_emnist_essentials()
- # Load dataset information.
- self._mapping = dict(self.essentials["mapping"])
- self._augment_emnist_mapping()
- self._inverse_mapping = {v: k for k, v in self.mapping.items()}
- self._num_classes = len(self.mapping)
- self._input_shape = self.essentials["input_shape"]
-
- def __call__(self, token: Union[str, int, np.uint8, Tensor]) -> Union[str, int]:
- """Maps the token to emnist character or character index.
-
- If the token is an integer (index), the method will return the Emnist character corresponding to that index.
- If the token is a str (Emnist character), the method will return the corresponding index for that character.
-
- Args:
- token (Union[str, int, np.uint8, Tensor]): Either a string or index (integer).
-
- Returns:
- Union[str, int]: The mapping result.
-
- Raises:
- KeyError: If the index or string does not exist in the mapping.
-
- """
- if (
- (isinstance(token, np.uint8) or isinstance(token, int))
- or torch.is_tensor(token)
- and int(token) in self.mapping
- ):
- return self.mapping[int(token)]
- elif isinstance(token, str) and token in self._inverse_mapping:
- return self._inverse_mapping[token]
- else:
- raise KeyError(f"Token {token} does not exist in the mappings.")
-
- @property
- def mapping(self) -> Dict:
- """Returns the mapping between index and character."""
- return self._mapping
-
- @property
- def inverse_mapping(self) -> Dict:
- """Returns the mapping between character and index."""
- return self._inverse_mapping
-
- @property
- def num_classes(self) -> int:
- """Returns the number of classes in the dataset."""
- return self._num_classes
-
- @property
- def input_shape(self) -> List[int]:
- """Returns the input shape of the Emnist characters."""
- return self._input_shape
-
- def _load_emnist_essentials(self) -> Dict:
- """Load the EMNIST mapping."""
- with open(str(ESSENTIALS_FILENAME)) as f:
- essentials = json.load(f)
- return essentials
-
- def _augment_emnist_mapping(self) -> None:
- """Augment the mapping with extra symbols."""
- # Extra symbols in IAM dataset
- if self.lower:
- self._mapping = {
- k: str(v)
- for k, v in enumerate(list(range(10)) + list(string.ascii_lowercase))
- }
-
- extra_symbols = [
- " ",
- "!",
- '"',
- "#",
- "&",
- "'",
- "(",
- ")",
- "*",
- "+",
- ",",
- "-",
- ".",
- "/",
- ":",
- ";",
- "?",
- ]
-
- # padding symbol, and acts as blank symbol as well.
- extra_symbols.append(self.pad_token)
-
- if self.init_token is not None:
- extra_symbols.append(self.init_token)
-
- if self.eos_token is not None:
- extra_symbols.append(self.eos_token)
-
- max_key = max(self.mapping.keys())
- extra_mapping = {}
- for i, symbol in enumerate(extra_symbols):
- extra_mapping[max_key + 1 + i] = symbol
-
- self._mapping = {**self.mapping, **extra_mapping}
-
-
-def compute_sha256(filename: Union[Path, str]) -> str:
- """Returns the SHA256 checksum of a file."""
- with open(filename, "rb") as f:
- return hashlib.sha256(f.read()).hexdigest()
-
-
-class TqdmUpTo(tqdm):
- """TQDM progress bar when downloading files.
-
- From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
-
- """
-
- def update_to(
- self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None
- ) -> None:
- """Updates the progress bar.
-
- Args:
- blocks (int): Number of blocks transferred so far. Defaults to 1.
- block_size (int): Size of each block, in tqdm units. Defaults to 1.
- total_size (Optional[int]): Total size in tqdm units. Defaults to None.
- """
- if total_size is not None:
- self.total = total_size # pylint: disable=attribute-defined-outside-init
- self.update(blocks * block_size - self.n)
-
-
-def download_url(url: str, filename: str) -> None:
- """Downloads a file from url to filename, with a progress bar."""
- with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
- urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec
-
-
-def _download_raw_dataset(metadata: Dict) -> None:
- if os.path.exists(metadata["filename"]):
- return
- logger.info(f"Downloading raw dataset from {metadata['url']}...")
- download_url(metadata["url"], metadata["filename"])
- logger.info("Computing SHA-256...")
- sha256 = compute_sha256(metadata["filename"])
- if sha256 != metadata["sha256"]:
- raise ValueError(
- "Downloaded data file SHA-256 does not match that listed in metadata document."
- )