author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-21 20:03:10 +0100
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-21 20:03:10 +0100
commit | aac452a2dc008338cb543549652da293c14b6b4e (patch)
tree | 6d018841e28f22eee5289f6cc59c28100a9d023d /text_recognizer
parent | a3a40c9c0118039460d5c9fba6a74edc0cdba106 (diff)
Refactor EMNIST dataset
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/datasets/base_data_module.py | 69
-rw-r--r-- | text_recognizer/datasets/download_utils.py | 73
-rw-r--r-- | text_recognizer/datasets/emnist.py | 194
-rw-r--r-- | text_recognizer/datasets/emnist_dataset.py | 131
-rw-r--r-- | text_recognizer/datasets/emnist_essentials.json | 1
5 files changed, 336 insertions, 132 deletions
diff --git a/text_recognizer/datasets/base_data_module.py b/text_recognizer/datasets/base_data_module.py
new file mode 100644
index 0000000..09a0a43
--- /dev/null
+++ b/text_recognizer/datasets/base_data_module.py
@@ -0,0 +1,69 @@
+"""Base lightning DataModule class."""
+from pathlib import Path
+from typing import Any, Dict
+
+import pytorch_lightning as pl
+from torch.utils.data import DataLoader
+
+
+def load_and_print_info(data_module_class: type) -> None:
+    """Load a data module and print its info."""
+    dataset = data_module_class()
+    dataset.prepare_data()
+    dataset.setup()
+    print(dataset)
+
+
+class BaseDataModule(pl.LightningDataModule):
+    """Base PyTorch Lightning DataModule."""
+
+    def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None:
+        super().__init__()
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+
+        # Placeholders for subclasses.
+        self.dims = None
+        self.output_dims = None
+        self.mapping = None
+
+    @classmethod
+    def data_dirname(cls) -> Path:
+        """Return the path to the base data directory."""
+        return Path(__file__).resolve().parents[2] / "data"
+
+    def config(self) -> Dict:
+        """Return important settings of the dataset."""
+        return {"input_dim": self.dims, "output_dims": self.output_dims, "mapping": self.mapping}
+
+    def prepare_data(self) -> None:
+        """Prepare data for training."""
+        pass
+
+    def setup(self, stage: Any = None) -> None:
+        """Split into train, val, test, and set dims.
+
+        Should assign `torch Dataset` objects to self.data_train, self.data_val, and
+        optionally self.data_test.
+
+        Args:
+            stage (Any): Variable to set splits.
+
+        """
+        self.data_train = None
+        self.data_val = None
+        self.data_test = None
+
+    def train_dataloader(self) -> DataLoader:
+        """Return DataLoader for train data."""
+        return DataLoader(self.data_train, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)
+
+    def val_dataloader(self) -> DataLoader:
+        """Return DataLoader for val data."""
+        return DataLoader(self.data_val, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)
+
+    def test_dataloader(self) -> DataLoader:
+        """Return DataLoader for test data."""
+        return DataLoader(self.data_test, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)
diff --git a/text_recognizer/datasets/download_utils.py b/text_recognizer/datasets/download_utils.py
new file mode 100644
index 0000000..7a2cab8
--- /dev/null
+++ b/text_recognizer/datasets/download_utils.py
@@ -0,0 +1,73 @@
+"""Util functions for downloading datasets."""
+import hashlib
+from pathlib import Path
+from typing import Dict, Optional
+from urllib.request import urlretrieve
+
+from loguru import logger
+from tqdm import tqdm
+
+
+def _compute_sha256(filename: Path) -> str:
+    """Return the SHA-256 checksum of a file."""
+    with filename.open(mode="rb") as f:
+        return hashlib.sha256(f.read()).hexdigest()
+
+
+class TqdmUpTo(tqdm):
+    """TQDM progress bar when downloading files.
+
+    From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
+
+    """
+
+    def update_to(
+        self, blocks: int = 1, block_size: int = 1, total_size: Optional[int] = None
+    ) -> None:
+        """Update the progress bar.
+
+        Args:
+            blocks (int): Number of blocks transferred so far. Defaults to 1.
+            block_size (int): Size of each block, in tqdm units. Defaults to 1.
+            total_size (Optional[int]): Total size in tqdm units. Defaults to None.
+ """ + if total_size is not None: + self.total = total_size # pylint: disable=attribute-defined-outside-init + self.update(blocks * block_size - self.n) + + +def _download_url(url: str, filename: str) -> None: + """Downloads a file from url to filename, with a progress bar.""" + with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: + urlretrieve(url, filename, reporthook=t.update_to, data=None) # nosec + + +def download_dataset(metadata: Dict, dl_dir: Path) -> Optional[Path]: + """Downloads dataset using a metadata file. + + Args: + metadata (Dict): A metadata file of the dataset. + dl_dir (Path): Download directory for the dataset. + + Returns: + Optional[Path]: Returns filename if dataset is downloaded, None if it already + exists. + + Raises: + ValueError: If the SHA-256 value is not the same between the dataset and + the metadata file. + + """ + dl_dir.mkdir(parents=True, exist_ok=True) + filename = dl_dir / metadata["filename"] + if filename.exists(): + return + logger.info(f"Downloading raw dataset from {metadata['url']} to {filename}...") + _download_url(metadata["url"], filename) + logger.info("Computing the SHA-256...") + sha256 = _compute_sha256(filename) + if sha256 != metadata["sha256"]: + raise ValueError( + "Downloaded data file SHA-256 does not match that listed in metadata document." + ) + return filename diff --git a/text_recognizer/datasets/emnist.py b/text_recognizer/datasets/emnist.py new file mode 100644 index 0000000..e99dbfd --- /dev/null +++ b/text_recognizer/datasets/emnist.py @@ -0,0 +1,194 @@ +"""EMNIST dataset: downloads it from FSDL aws url if not present.""" +from pathlib import Path +from typing import Sequence, Tuple +import json +import os +import shutil +import zipfile + +import h5py +import numpy as np +from loguru import logger +import toml +import torch +from torch.utils.data import random_split +from torchvision import transforms + +from text_recognizer.datasets.base_dataset import BaseDataset +from text_recognizer.datasets.base_data_module import BaseDataModule, load_print_info +from text_recognizer.datasets.download_utils import download_dataset + + +SEED = 4711 +NUM_SPECIAL_TOKENS = 4 +SAMPLE_TO_BALANCE = True + +RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "emnist" +METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" +DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "emnist" +PROCESSED_DATA_DIRNAME = BaseDataset.data_dirname() / "processed" / "emnist" +PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5" +ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnsit_essentials.json" + + +class EMNIST(BaseDataModule): + """ + "The EMNIST dataset is a set of handwritten character digits derived from the NIST Special Database 19 + and converted to a 28x28 pixel image format and dataset structure that directly matches the MNIST dataset." + From https://www.nist.gov/itl/iad/image-group/emnist-dataset + + The data split we will use is + EMNIST ByClass: 814,255 characters. 62 unbalanced classes. 
+ """ + + def __init__(self, batch_size: int = 128, num_workers: int = 0, train_fraction: float = 0.8) -> None: + super().__init__(batch_size, num_workers) + if not ESSENTIALS_FILENAME.exists(): + _download_and_process_emnist() + with ESSENTIALS_FILENAME.open() as f: + essentials = json.load(f) + self.train_fraction = train_fraction + self.mapping = list(essentials["characters"]) + self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)} + self.data_train = None + self.data_val = None + self.data_test = None + self.transform = transforms.Compose([transforms.ToTensor()]) + self.dims = (1, *essentials["input_shape"]) + self.output_dims = (1,) + + def prepare_data(self) -> None: + if not PROCESSED_DATA_FILENAME.exists(): + _download_and_process_emnist() + + def setup(self, stage: str = None) -> None: + if stage == "fit" or stage is None: + with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: + data = f["x_train"][:] + targets = f["y_train"][:] + + dataset_train = BaseDataset(data, targets, transform=self.transform) + train_size = int(self.train_fraction * len(dataset_train)) + val_size = len(dataset_train) - train_size + self.data_train, self.data_val = random_split(dataset_train, [train_size, val_size], generator=torch.Generator()) + + if stage == "test" or stage is None: + with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: + data = f["x_test"][:] + targets = f["y_test"][:] + self.data_test = BaseDataset(data, targets, transform=self.transform) + + + def __repr__(self) -> str: + basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.dims}\n" + if not any([self.data_train, self.data_val, self.data_test]): + return basic + + datum, target = next(iter(self.train_dataloader())) + data = ( + f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" + f"Batch x stats: {(datum.shape, datum.dtype, datum.min(), datum.mean(), datum.std(), datum.max())}\n" + f"Batch y stats: {(target.shape, target.dtype, target.min(), target.max())}\n" + ) + + return basic + data + + +def _download_and_process_emnist() -> None: + metadata = toml.load(METADATA_FILENAME) + download_dataset(metadata, DL_DATA_DIRNAME) + _process_raw_dataset(metadata["filename"], DL_DATA_DIRNAME) + + +def _process_raw_dataset(filename: str, dirname: Path) -> None: + logger.info("Unzipping EMNIST...") + curdir = os.getcwd() + os.chdir(dirname) + content = zipfile.ZipFile(filename, "r") + content.extract("matlab/emnist-byclass.mat") + + from scipy.io import loadmat + + logger.info("Loading training data from .mat file") + data = loadmat("matlab/emnist-byclass.mat") + x_train = data["dataset"]["train"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) + y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS + x_test = data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) + y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS + + if SAMPLE_TO_BALANCE: + logger.info("Balancing classes to reduce amount of data") + x_train, y_train = _sample_to_balance(x_train, y_train) + x_test, y_test = _sample_to_balance(x_test, y_test) + + + logger.info("Saving to HDF5 in a compressed format...") + PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) + with h5py.File(PROCESSED_DATA_FILENAME, "w") as f: + f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf") + f.create_dataset("y_train", data=y_train, dtype="u1", compression="lzf") + f.create_dataset("x_test", data=x_test, 
dtype="u1", compression="lzf") + f.create_dataset("y_test", data=y_test, dtype="u1", compression="lzf") + + logger.info("Saving essential dataset parameters to text_recognizer/datasets...") + mapping = {int(k): chr(v) for k, v in data["dataset"]["mapping"][0, 0]} + characters = _augment_emnist_characters(mapping.values()) + essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])} + + with ESSENTIALS_FILENAME.open(mode="w") as f: + json.dump(essentials, f) + + logger.info("Cleaning up...") + shutil.rmtree("matlab") + os.chdir(curdir) + + +def _sample_to_balance(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Balances the dataset by taking the mean number of instances per class.""" + np.random.seed(SEED) + num_to_sample = int(np.bincount(y.flatten()).mean()) + all_sampled_indices = [] + for label in np.unique(y.flatten()): + indices = np.where(y == label)[0] + sampled_indices = np.unique(np.random.choice(indices, num_to_sample)) + all_sampled_indices.append(sampled_indices) + indices = np.concatenate(all_sampled_indices) + x_sampled = x[indices] + y_sampled= y[indices] + return x_sampled, y_sampled + + +def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]: + """Augment the mapping with extra symbols.""" + # Extra characters from the IAM dataset. + iam_characters = [ + " ", + "!", + '"', + "#", + "&", + "'", + "(", + ")", + "*", + "+", + ",", + "-", + ".", + "/", + ":", + ";", + "?", + ] + + # Also add special tokens for: + # - CTC blank token at index 0 + # - Start token at index 1 + # - End token at index 2 + # - Padding token at index 3 + # Note: Do not forget to update NUM_SPECIAL_TOKENS if changing this! + return ["<b>", "<s>", "</s>", "<p>", *characters, *iam_characters] + + +if __name__ == "__main__": + load_print_info(EMNIST) diff --git a/text_recognizer/datasets/emnist_dataset.py b/text_recognizer/datasets/emnist_dataset.py deleted file mode 100644 index 9884fdf..0000000 --- a/text_recognizer/datasets/emnist_dataset.py +++ /dev/null @@ -1,131 +0,0 @@ -"""Emnist dataset: black and white images of handwritten characters (Aa-Zz) and digits (0-9).""" - -import json -from pathlib import Path -from typing import Callable, Optional, Tuple, Union - -from loguru import logger -import numpy as np -from PIL import Image -import torch -from torch import Tensor -from torchvision.datasets import EMNIST -from torchvision.transforms import Compose, ToTensor - -from text_recognizer.datasets.dataset import Dataset -from text_recognizer.datasets.transforms import Transpose -from text_recognizer.datasets.util import DATA_DIRNAME - - -class EmnistDataset(Dataset): - """This is a class for resampling and subsampling the PyTorch EMNIST dataset.""" - - def __init__( - self, - pad_token: str = None, - train: bool = False, - sample_to_balance: bool = False, - subsample_fraction: float = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - seed: int = 4711, - ) -> None: - """Loads the dataset and the mappings. - - Args: - pad_token (str): The pad token symbol. Defaults to _. - train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False. - sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False. - subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None. - transform (Optional[Callable]): Transform(s) for input data. Defaults to None. 
-            target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None.
-            seed (int): Seed number. Defaults to 4711.
-
-        """
-        super().__init__(
-            train=train,
-            subsample_fraction=subsample_fraction,
-            transform=transform,
-            target_transform=target_transform,
-            pad_token=pad_token,
-        )
-
-        self.sample_to_balance = sample_to_balance
-
-        # Have to transpose the emnist characters, ToTensor norms input between [0,1].
-        if transform is None:
-            self.transform = Compose([Transpose(), ToTensor()])
-
-        self.target_transform = None
-
-        self.seed = seed
-
-    def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
-        """Fetches samples from the dataset.
-
-        Args:
-            index (Union[int, Tensor]): The indices of the samples to fetch.
-
-        Returns:
-            Tuple[Tensor, Tensor]: Data target tuple.
-
-        """
-        if torch.is_tensor(index):
-            index = index.tolist()
-
-        data = self.data[index]
-        targets = self.targets[index]
-
-        if self.transform:
-            data = self.transform(data)
-
-        if self.target_transform:
-            targets = self.target_transform(targets)
-
-        return data, targets
-
-    def __repr__(self) -> str:
-        """Returns information about the dataset."""
-        return (
-            "EMNIST Dataset\n"
-            f"Num classes: {self.num_classes}\n"
-            f"Input shape: {self.input_shape}\n"
-            f"Mapping: {self.mapper.mapping}\n"
-        )
-
-    def _sample_to_balance(self) -> None:
-        """Because the dataset is not balanced, we take at most the mean number of instances per class."""
-        np.random.seed(self.seed)
-        x = self._data
-        y = self._targets
-        num_to_sample = int(np.bincount(y.flatten()).mean())
-        all_sampled_indices = []
-        for label in np.unique(y.flatten()):
-            inds = np.where(y == label)[0]
-            sampled_indices = np.unique(np.random.choice(inds, num_to_sample))
-            all_sampled_indices.append(sampled_indices)
-        indices = np.concatenate(all_sampled_indices)
-        x_sampled = x[indices]
-        y_sampled = y[indices]
-        self._data = x_sampled
-        self._targets = y_sampled
-
-    def load_or_generate_data(self) -> None:
-        """Fetch the EMNIST dataset."""
-        dataset = EMNIST(
-            root=DATA_DIRNAME,
-            split="byclass",
-            train=self.train,
-            download=False,
-            transform=None,
-            target_transform=None,
-        )
-
-        self._data = dataset.data
-        self._targets = dataset.targets
-
-        if self.sample_to_balance:
-            self._sample_to_balance()
-
-        if self.subsample_fraction is not None:
-            self._subsample()
diff --git a/text_recognizer/datasets/emnist_essentials.json b/text_recognizer/datasets/emnist_essentials.json
deleted file mode 100644
index 2a0648a..0000000
--- a/text_recognizer/datasets/emnist_essentials.json
+++ /dev/null
@@ -1 +0,0 @@
-{"mapping": [[0, "0"], [1, "1"], [2, "2"], [3, "3"], [4, "4"], [5, "5"], [6, "6"], [7, "7"], [8, "8"], [9, "9"], [10, "A"], [11, "B"], [12, "C"], [13, "D"], [14, "E"], [15, "F"], [16, "G"], [17, "H"], [18, "I"], [19, "J"], [20, "K"], [21, "L"], [22, "M"], [23, "N"], [24, "O"], [25, "P"], [26, "Q"], [27, "R"], [28, "S"], [29, "T"], [30, "U"], [31, "V"], [32, "W"], [33, "X"], [34, "Y"], [35, "Z"], [36, "a"], [37, "b"], [38, "c"], [39, "d"], [40, "e"], [41, "f"], [42, "g"], [43, "h"], [44, "i"], [45, "j"], [46, "k"], [47, "l"], [48, "m"], [49, "n"], [50, "o"], [51, "p"], [52, "q"], [53, "r"], [54, "s"], [55, "t"], [56, "u"], [57, "v"], [58, "w"], [59, "x"], [60, "y"], [61, "z"]], "input_shape": [28, 28]}
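
Usage note (not part of the commit): a minimal sketch of exercising the new EMNIST data module directly, assuming the raw-data metadata.toml and the BaseDataset class referenced above are available locally. The batch shapes in the comment are illustrative.

    from text_recognizer.datasets.emnist import EMNIST

    # Instantiate the DataModule, download/convert EMNIST on first run, and build the splits.
    datamodule = EMNIST(batch_size=64, num_workers=4, train_fraction=0.8)
    datamodule.prepare_data()  # downloads the raw zip and writes the processed byclass.h5 if missing
    datamodule.setup()         # assigns data_train, data_val, and data_test

    # Fetch one training batch to sanity-check shapes.
    images, labels = next(iter(datamodule.train_dataloader()))
    print(images.shape, labels.shape)  # e.g. torch.Size([64, 1, 28, 28]) torch.Size([64])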