summaryrefslogtreecommitdiff
path: root/text_recognizer/data/emnist.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/emnist.py')
-rw-r--r--text_recognizer/data/emnist.py210
1 files changed, 210 insertions, 0 deletions
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
new file mode 100644
index 0000000..7f67893
--- /dev/null
+++ b/text_recognizer/data/emnist.py
@@ -0,0 +1,210 @@
+"""EMNIST dataset: downloads it from FSDL aws url if not present."""
+from pathlib import Path
+from typing import Sequence, Tuple
+import json
+import os
+import shutil
+import zipfile
+
+import h5py
+import numpy as np
+from loguru import logger
+import toml
+import torch
+from torch.utils.data import random_split
+from torchvision import transforms
+
+from text_recognizer.data.base_dataset import BaseDataset
+from text_recognizer.data.base_data_module import (
+ BaseDataModule,
+ load_and_print_info,
+)
+from text_recognizer.data.download_utils import download_dataset
+
+
+SEED = 4711
+NUM_SPECIAL_TOKENS = 4
+SAMPLE_TO_BALANCE = True
+
+RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "emnist"
+METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
+DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "emnist"
+PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist"
+PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5"
+ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json"
+
+
+class EMNIST(BaseDataModule):
+ """
+ "The EMNIST dataset is a set of handwritten character digits derived from the NIST Special Database 19
+ and converted to a 28x28 pixel image format and dataset structure that directly matches the MNIST dataset."
+ From https://www.nist.gov/itl/iad/image-group/emnist-dataset
+
+ The data split we will use is
+ EMNIST ByClass: 814,255 characters. 62 unbalanced classes.
+ """
+
+ def __init__(
+ self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8
+ ) -> None:
+ super().__init__(batch_size, num_workers)
+ if not ESSENTIALS_FILENAME.exists():
+ _download_and_process_emnist()
+ with ESSENTIALS_FILENAME.open() as f:
+ essentials = json.load(f)
+ self.train_fraction = train_fraction
+ self.mapping = list(essentials["characters"])
+ self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)}
+ self.data_train = None
+ self.data_val = None
+ self.data_test = None
+ self.transform = transforms.Compose([transforms.ToTensor()])
+ self.dims = (1, *essentials["input_shape"])
+ self.output_dims = (1,)
+
+ def prepare_data(self) -> None:
+ if not PROCESSED_DATA_FILENAME.exists():
+ _download_and_process_emnist()
+
+ def setup(self, stage: str = None) -> None:
+ if stage == "fit" or stage is None:
+ with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
+ self.x_train = f["x_train"][:]
+ self.y_train = f["y_train"][:].squeeze().astype(int)
+
+ dataset_train = BaseDataset(
+ self.x_train, self.y_train, transform=self.transform
+ )
+ train_size = int(self.train_fraction * len(dataset_train))
+ val_size = len(dataset_train) - train_size
+ self.data_train, self.data_val = random_split(
+ dataset_train, [train_size, val_size], generator=torch.Generator()
+ )
+
+ if stage == "test" or stage is None:
+ with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
+ self.x_test = f["x_test"][:]
+ self.y_test = f["y_test"][:].squeeze().astype(int)
+ self.data_test = BaseDataset(
+ self.x_test, self.y_test, transform=self.transform
+ )
+
+ def __repr__(self) -> str:
+ basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.dims}\n"
+ if not any([self.data_train, self.data_val, self.data_test]):
+ return basic
+
+ datum, target = next(iter(self.train_dataloader()))
+ data = (
+ f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n"
+ f"Batch x stats: {(datum.shape, datum.dtype, datum.min(), datum.mean(), datum.std(), datum.max())}\n"
+ f"Batch y stats: {(target.shape, target.dtype, target.min(), target.max())}\n"
+ )
+
+ return basic + data
+
+
+def _download_and_process_emnist() -> None:
+ metadata = toml.load(METADATA_FILENAME)
+ download_dataset(metadata, DL_DATA_DIRNAME)
+ _process_raw_dataset(metadata["filename"], DL_DATA_DIRNAME)
+
+
+def _process_raw_dataset(filename: str, dirname: Path) -> None:
+ logger.info("Unzipping EMNIST...")
+ curdir = os.getcwd()
+ os.chdir(dirname)
+ content = zipfile.ZipFile(filename, "r")
+ content.extract("matlab/emnist-byclass.mat")
+
+ from scipy.io import loadmat
+
+ logger.info("Loading training data from .mat file")
+ data = loadmat("matlab/emnist-byclass.mat")
+ x_train = (
+ data["dataset"]["train"][0, 0]["images"][0, 0]
+ .reshape(-1, 28, 28)
+ .swapaxes(1, 2)
+ )
+ y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
+ x_test = (
+ data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2)
+ )
+ y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
+
+ if SAMPLE_TO_BALANCE:
+ logger.info("Balancing classes to reduce amount of data")
+ x_train, y_train = _sample_to_balance(x_train, y_train)
+ x_test, y_test = _sample_to_balance(x_test, y_test)
+
+ logger.info("Saving to HDF5 in a compressed format...")
+ PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
+ with h5py.File(PROCESSED_DATA_FILENAME, "w") as f:
+ f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf")
+ f.create_dataset("y_train", data=y_train, dtype="u1", compression="lzf")
+ f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf")
+ f.create_dataset("y_test", data=y_test, dtype="u1", compression="lzf")
+
+ logger.info("Saving essential dataset parameters to text_recognizer/datasets...")
+ mapping = {int(k): chr(v) for k, v in data["dataset"]["mapping"][0, 0]}
+ characters = _augment_emnist_characters(mapping.values())
+ essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])}
+
+ with ESSENTIALS_FILENAME.open(mode="w") as f:
+ json.dump(essentials, f)
+
+ logger.info("Cleaning up...")
+ shutil.rmtree("matlab")
+ os.chdir(curdir)
+
+
+def _sample_to_balance(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ """Balances the dataset by taking the mean number of instances per class."""
+ np.random.seed(SEED)
+ num_to_sample = int(np.bincount(y.flatten()).mean())
+ all_sampled_indices = []
+ for label in np.unique(y.flatten()):
+ indices = np.where(y == label)[0]
+ sampled_indices = np.unique(np.random.choice(indices, num_to_sample))
+ all_sampled_indices.append(sampled_indices)
+ indices = np.concatenate(all_sampled_indices)
+ x_sampled = x[indices]
+ y_sampled = y[indices]
+ return x_sampled, y_sampled
+
+
+def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]:
+ """Augment the mapping with extra symbols."""
+ # Extra characters from the IAM dataset.
+ iam_characters = [
+ " ",
+ "!",
+ '"',
+ "#",
+ "&",
+ "'",
+ "(",
+ ")",
+ "*",
+ "+",
+ ",",
+ "-",
+ ".",
+ "/",
+ ":",
+ ";",
+ "?",
+ ]
+
+ # Also add special tokens for:
+ # - CTC blank token at index 0
+ # - Start token at index 1
+ # - End token at index 2
+ # - Padding token at index 3
+ # Note: Do not forget to update NUM_SPECIAL_TOKENS if changing this!
+ return ["<b>", "<s>", "</s>", "<p>", *characters, *iam_characters]
+
+
+def download_emnist() -> None:
+ """Download dataset from internet, if it does not exists, and displays info."""
+ load_and_print_info(EMNIST)