| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-24 22:15:54 +0100 |
|---|---|---|
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-24 22:15:54 +0100 |
| commit | 8248f173132dfb7e47ec62b08e9235990c8626e3 (patch) | |
| tree | 2f3ff85602cbc08b7168bf4f0d3924d32a689852 /text_recognizer/datasets/emnist.py | |
| parent | 74c907a17379688967dc4b3f41a44ba83034f5e0 (diff) | |
renamed datasets to data, added iam refactor
Diffstat (limited to 'text_recognizer/datasets/emnist.py')
-rw-r--r-- | text_recognizer/datasets/emnist.py | 210 |
1 file changed, 0 insertions, 210 deletions
diff --git a/text_recognizer/datasets/emnist.py b/text_recognizer/datasets/emnist.py
deleted file mode 100644
index 66101b5..0000000
--- a/text_recognizer/datasets/emnist.py
+++ /dev/null
@@ -1,210 +0,0 @@
-"""EMNIST dataset: downloads it from the FSDL AWS URL if not present."""
-from pathlib import Path
-from typing import Sequence, Tuple
-import json
-import os
-import shutil
-import zipfile
-
-import h5py
-import numpy as np
-from loguru import logger
-import toml
-import torch
-from torch.utils.data import random_split
-from torchvision import transforms
-
-from text_recognizer.datasets.base_dataset import BaseDataset
-from text_recognizer.datasets.base_data_module import (
-    BaseDataModule,
-    load_and_print_info,
-)
-from text_recognizer.datasets.download_utils import download_dataset
-
-
-SEED = 4711
-NUM_SPECIAL_TOKENS = 4
-SAMPLE_TO_BALANCE = True
-
-RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "emnist"
-METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
-DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "emnist"
-PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "emnist"
-PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5"
-ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json"
-
-
-class EMNIST(BaseDataModule):
-    """
-    "The EMNIST dataset is a set of handwritten character digits derived from the NIST Special Database 19
-    and converted to a 28x28 pixel image format and dataset structure that directly matches the MNIST dataset."
-    From https://www.nist.gov/itl/iad/image-group/emnist-dataset
-
-    The data split we will use is
-    EMNIST ByClass: 814,255 characters. 62 unbalanced classes.
-    """
-
-    def __init__(
-        self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8
-    ) -> None:
-        super().__init__(batch_size, num_workers)
-        if not ESSENTIALS_FILENAME.exists():
-            _download_and_process_emnist()
-        with ESSENTIALS_FILENAME.open() as f:
-            essentials = json.load(f)
-        self.train_fraction = train_fraction
-        self.mapping = list(essentials["characters"])
-        self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)}
-        self.data_train = None
-        self.data_val = None
-        self.data_test = None
-        self.transform = transforms.Compose([transforms.ToTensor()])
-        self.dims = (1, *essentials["input_shape"])
-        self.output_dims = (1,)
-
-    def prepare_data(self) -> None:
-        if not PROCESSED_DATA_FILENAME.exists():
-            _download_and_process_emnist()
-
-    def setup(self, stage: str = None) -> None:
-        if stage == "fit" or stage is None:
-            with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
-                self.x_train = f["x_train"][:]
-                self.y_train = f["y_train"][:].squeeze().astype(int)
-
-            dataset_train = BaseDataset(
-                self.x_train, self.y_train, transform=self.transform
-            )
-            train_size = int(self.train_fraction * len(dataset_train))
-            val_size = len(dataset_train) - train_size
-            self.data_train, self.data_val = random_split(
-                dataset_train, [train_size, val_size], generator=torch.Generator()
-            )
-
-        if stage == "test" or stage is None:
-            with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
-                self.x_test = f["x_test"][:]
-                self.y_test = f["y_test"][:].squeeze().astype(int)
-            self.data_test = BaseDataset(
-                self.x_test, self.y_test, transform=self.transform
-            )
-
-    def __repr__(self) -> str:
-        basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.dims}\n"
-        if not any([self.data_train, self.data_val, self.data_test]):
-            return basic
-
-        datum, target = next(iter(self.train_dataloader()))
-        data = (
-            f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n"
-            f"Batch x stats: {(datum.shape, datum.dtype, datum.min(), datum.mean(), datum.std(), datum.max())}\n"
-            f"Batch y stats: {(target.shape, target.dtype, target.min(), target.max())}\n"
-        )
-
-        return basic + data
-
-
-def _download_and_process_emnist() -> None:
-    metadata = toml.load(METADATA_FILENAME)
-    download_dataset(metadata, DL_DATA_DIRNAME)
-    _process_raw_dataset(metadata["filename"], DL_DATA_DIRNAME)
-
-
-def _process_raw_dataset(filename: str, dirname: Path) -> None:
-    logger.info("Unzipping EMNIST...")
-    curdir = os.getcwd()
-    os.chdir(dirname)
-    content = zipfile.ZipFile(filename, "r")
-    content.extract("matlab/emnist-byclass.mat")
-
-    from scipy.io import loadmat
-
-    logger.info("Loading training data from .mat file")
-    data = loadmat("matlab/emnist-byclass.mat")
-    x_train = (
-        data["dataset"]["train"][0, 0]["images"][0, 0]
-        .reshape(-1, 28, 28)
-        .swapaxes(1, 2)
-    )
-    y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
-    x_test = (
-        data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2)
-    )
-    y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
-
-    if SAMPLE_TO_BALANCE:
-        logger.info("Balancing classes to reduce amount of data")
-        x_train, y_train = _sample_to_balance(x_train, y_train)
-        x_test, y_test = _sample_to_balance(x_test, y_test)
-
-    logger.info("Saving to HDF5 in a compressed format...")
-    PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
-    with h5py.File(PROCESSED_DATA_FILENAME, "w") as f:
-        f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf")
-        f.create_dataset("y_train", data=y_train, dtype="u1", compression="lzf")
-        f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf")
-        f.create_dataset("y_test", data=y_test, dtype="u1", compression="lzf")
-
-    logger.info("Saving essential dataset parameters to text_recognizer/datasets...")
-    mapping = {int(k): chr(v) for k, v in data["dataset"]["mapping"][0, 0]}
-    characters = _augment_emnist_characters(mapping.values())
-    essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])}
-
-    with ESSENTIALS_FILENAME.open(mode="w") as f:
-        json.dump(essentials, f)
-
-    logger.info("Cleaning up...")
-    shutil.rmtree("matlab")
-    os.chdir(curdir)
-
-
-def _sample_to_balance(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
-    """Balances the dataset by taking the mean number of instances per class."""
-    np.random.seed(SEED)
-    num_to_sample = int(np.bincount(y.flatten()).mean())
-    all_sampled_indices = []
-    for label in np.unique(y.flatten()):
-        indices = np.where(y == label)[0]
-        sampled_indices = np.unique(np.random.choice(indices, num_to_sample))
-        all_sampled_indices.append(sampled_indices)
-    indices = np.concatenate(all_sampled_indices)
-    x_sampled = x[indices]
-    y_sampled = y[indices]
-    return x_sampled, y_sampled
-
-
-def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]:
-    """Augment the mapping with extra symbols."""
-    # Extra characters from the IAM dataset.
-    iam_characters = [
-        " ",
-        "!",
-        '"',
-        "#",
-        "&",
-        "'",
-        "(",
-        ")",
-        "*",
-        "+",
-        ",",
-        "-",
-        ".",
-        "/",
-        ":",
-        ";",
-        "?",
-    ]
-
-    # Also add special tokens for:
-    # - CTC blank token at index 0
-    # - Start token at index 1
-    # - End token at index 2
-    # - Padding token at index 3
-    # Note: Do not forget to update NUM_SPECIAL_TOKENS if changing this!
-    return ["<b>", "<s>", "</s>", "<p>", *characters, *iam_characters]
-
-
-def download_emnist() -> None:
-    """Download the dataset from the internet, if it does not exist, and display info."""
-    load_and_print_info(EMNIST)
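
For reference, this is roughly how the deleted datamodule was driven; a minimal sketch, assuming the repo's BaseDataModule plumbing and the pre-rename import path:

    from text_recognizer.datasets.emnist import EMNIST

    # First run downloads and converts the raw data; later runs reuse byclass.h5.
    dm = EMNIST(batch_size=128, num_workers=0, train_fraction=0.8)
    dm.prepare_data()
    dm.setup()

    # train_dataloader() comes from BaseDataModule; __repr__ above uses the same call.
    x, y = next(iter(dm.train_dataloader()))
    print(dm)  # prints class count, mapping, dims, and batch stats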
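A toy illustration (not from the repo) of the balancing rule in _sample_to_balance, which resamples each class toward the mean class count:

    import numpy as np

    # Two classes: 100 examples of class 0, 10 of class 1.
    y = np.array([0] * 100 + [1] * 10)
    num_to_sample = int(np.bincount(y).mean())  # (100 + 10) / 2 = 55

    # np.random.choice draws WITH replacement and np.unique drops duplicates,
    # so each class contributes at most num_to_sample unique indices: class 0
    # shrinks toward 55 examples, class 1 keeps at most its original 10.
    indices = np.where(y == 1)[0]
    sampled = np.unique(np.random.choice(indices, num_to_sample))
    assert len(sampled) <= 10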
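And a hypothetical sketch of the index layout produced by _augment_emnist_characters, which explains the + NUM_SPECIAL_TOKENS shift applied to the raw labels in _process_raw_dataset:

    # Indices 0-3 are reserved for the special tokens, so a raw EMNIST ByClass
    # label l maps to mapping[l + NUM_SPECIAL_TOKENS]. The character list here
    # is a stand-in for the real 62 ByClass characters plus IAM punctuation.
    emnist_chars = ["0", "1", "2"]
    mapping = ["<b>", "<s>", "</s>", "<p>", *emnist_chars]
    raw_label = 0  # in ByClass, raw label 0 is the digit "0"
    assert mapping[raw_label + 4] == "0"  # 4 == NUM_SPECIAL_TOKENS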