diff options
Diffstat (limited to 'src/text_recognizer/datasets/emnist_dataset.py')
-rw-r--r-- | src/text_recognizer/datasets/emnist_dataset.py | 279 |
1 files changed, 205 insertions, 74 deletions
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py index b92b57d..525df95 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -1,29 +1,23 @@ -"""Fetches a PyTorch DataLoader with the EMNIST dataset.""" +"""Emnist dataset: black and white images of handwritten characters (Aa-Zz) and digits (0-9).""" import json from pathlib import Path -from typing import Callable, Dict, List, Optional, Type +from typing import Callable, Dict, List, Optional, Tuple, Type, Union from loguru import logger import numpy as np from PIL import Image -from torch.utils.data import DataLoader +import torch +from torch.utils.data import DataLoader, Dataset from torchvision.datasets import EMNIST -from torchvision.transforms import Compose, ToTensor +from torchvision.transforms import Compose, Normalize, ToTensor +from text_recognizer.datasets.util import Transpose DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data" ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json" -class Transpose: - """Transposes the EMNIST image to the correct orientation.""" - - def __call__(self, image: Image) -> np.ndarray: - """Swaps axis.""" - return np.array(image).swapaxes(0, 1) - - def save_emnist_essentials(emnsit_dataset: type = EMNIST) -> None: """Extract and saves EMNIST essentials.""" labels = emnsit_dataset.classes @@ -45,14 +39,187 @@ def download_emnist() -> None: save_emnist_essentials(dataset) -def load_emnist_mapping() -> Dict[int, str]: +def _load_emnist_essentials() -> Dict: """Load the EMNIST mapping.""" with open(str(ESSENTIALS_FILENAME)) as f: essentials = json.load(f) - return dict(essentials["mapping"]) + return essentials + + +def _augment_emnist_mapping(mapping: Dict) -> Dict: + """Augment the mapping with extra symbols.""" + # Extra symbols in IAM dataset + extra_symbols = [ + " ", + "!", + '"', + "#", + "&", + "'", + "(", + ")", + 
"*", + "+", + ",", + "-", + ".", + "/", + ":", + ";", + "?", + ] + + # padding symbol + extra_symbols.append("_") + + max_key = max(mapping.keys()) + extra_mapping = {} + for i, symbol in enumerate(extra_symbols): + extra_mapping[max_key + 1 + i] = symbol + + return {**mapping, **extra_mapping} + + +class EmnistDataset(Dataset): + """This is a class for resampling and subsampling the PyTorch EMNIST dataset.""" + + def __init__( + self, + train: bool = False, + sample_to_balance: bool = False, + subsample_fraction: float = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + seed: int = 4711, + ) -> None: + """Loads the dataset and the mappings. + + Args: + train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to + False. + sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False. + subsample_fraction (float): Fraction of the dataset to subsample, in the range (0, 1). Defaults to None. + transform (Optional[Callable]): Transform(s) for input data. Defaults to None. + target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None. + seed (int): Seed number. Defaults to 4711. + + Raises: + ValueError: If subsample_fraction is not None and outside the range (0, 1). + + """ + + self.train = train + self.sample_to_balance = sample_to_balance + if subsample_fraction is not None: + if not 0.0 < subsample_fraction < 1.0: + raise ValueError("The subsample fraction must be in (0, 1).") + self.subsample_fraction = subsample_fraction + self.transform = transform + if self.transform is None: + self.transform = Compose([Transpose(), ToTensor()]) + + self.target_transform = target_transform + self.seed = seed + + # Load dataset information. 
+ essentials = _load_emnist_essentials() + self.mapping = _augment_emnist_mapping(dict(essentials["mapping"])) + self.inverse_mapping = {v: k for k, v in self.mapping.items()} + self.num_classes = len(self.mapping) + self.input_shape = essentials["input_shape"] + + # Placeholders + self.data = None + self.targets = None + + def __len__(self) -> int: + """Returns the length of the dataset.""" + return len(self.data) + + def __getitem__( + self, index: Union[int, torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Fetches samples from the dataset. + + Args: + index (Union[int, torch.Tensor]): The indices of the samples to fetch. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Data target tuple. + + """ + if torch.is_tensor(index): + index = index.tolist() + + data = self.data[index] + targets = self.targets[index] + + if self.transform: + data = self.transform(data) + + if self.target_transform: + targets = self.target_transform(targets) + + return data, targets + + def __repr__(self) -> str: + """Returns information about the dataset.""" + return ( + "EMNIST Dataset\n" + f"Num classes: {self.num_classes}\n" + f"Mapping: {self.mapping}\n" + f"Input shape: {self.input_shape}\n" + ) + + def _sample_to_balance(self) -> None: + """Because the dataset is not balanced, we take at most the mean number of instances per class.""" + np.random.seed(self.seed) + x = self.data + y = self.targets + num_to_sample = int(np.bincount(y.flatten()).mean()) + all_sampled_indices = [] + for label in np.unique(y.flatten()): + inds = np.where(y == label)[0] + sampled_indices = np.unique(np.random.choice(inds, num_to_sample)) + all_sampled_indices.append(sampled_indices) + indices = np.concatenate(all_sampled_indices) + x_sampled = x[indices] + y_sampled = y[indices] + self.data = x_sampled + self.targets = y_sampled + + def _subsample(self) -> None: + """Subsamples the dataset to the specified fraction.""" + x = self.data + y = self.targets + num_samples = int(x.shape[0] * 
self.subsample_fraction) + x_sampled = x[:num_samples] + y_sampled = y[:num_samples] + self.data = x_sampled + self.targets = y_sampled + + def load_emnist_dataset(self) -> None: + """Fetch the EMNIST dataset.""" + dataset = EMNIST( + root=DATA_DIRNAME, + split="byclass", + train=self.train, + download=False, + transform=None, + target_transform=None, + ) + + self.data = dataset.data + self.targets = dataset.targets + + if self.sample_to_balance: + self._sample_to_balance() + + if self.subsample_fraction is not None: + self._subsample() -class EmnistDataLoader: +class EmnistDataLoaders: """Class for Emnist DataLoaders.""" def __init__( @@ -68,7 +235,7 @@ class EmnistDataLoader: cuda: bool = True, seed: int = 4711, ) -> None: - """Fetches DataLoaders. + """Fetches DataLoaders for given split(s). Args: splits (List[str]): One or both of the dataset splits "train" and "val". @@ -88,13 +255,17 @@ class EmnistDataLoader: them. Defaults to True. seed (int): Seed for sampling. + Raises: + ValueError: If subsample_fraction is not None and outside the range (0, 1). + """ self.splits = splits self.sample_to_balance = sample_to_balance + if subsample_fraction is not None: - assert ( - 0.0 < subsample_fraction < 1.0 - ), " The subsample fraction must be in (0, 1)." 
+ if not 0.0 < subsample_fraction < 1.0: + raise ValueError("The subsample fraction must be in (0, 1).") + self.subsample_fraction = subsample_fraction self.transform = transform self.target_transform = target_transform @@ -105,6 +276,10 @@ class EmnistDataLoader: self.seed = seed self._data_loaders = self._fetch_emnist_data_loaders() + def __repr__(self) -> str: + """Returns information about the dataset.""" + return self._data_loaders[self.splits[0]].dataset.__repr__() + @property def __name__(self) -> str: """Returns the name of the dataset.""" @@ -128,59 +303,6 @@ class EmnistDataLoader: except KeyError: raise ValueError(f"Split {split} does not exist.") - def _sample_to_balance(self, dataset: type = EMNIST) -> EMNIST: - """Because the dataset is not balanced, we take at most the mean number of instances per class.""" - np.random.seed(self.seed) - x = dataset.data - y = dataset.targets - num_to_sample = int(np.bincount(y.flatten()).mean()) - all_sampled_indices = [] - for label in np.unique(y.flatten()): - inds = np.where(y == label)[0] - sampled_indices = np.unique(np.random.choice(inds, num_to_sample)) - all_sampled_indices.append(sampled_indices) - indices = np.concatenate(all_sampled_indices) - x_sampled = x[indices] - y_sampled = y[indices] - dataset.data = x_sampled - dataset.targets = y_sampled - - return dataset - - def _subsample(self, dataset: type = EMNIST) -> EMNIST: - """Subsamples the dataset to the specified fraction.""" - x = dataset.data - y = dataset.targets - num_samples = int(x.shape[0] * self.subsample_fraction) - x_sampled = x[:num_samples] - y_sampled = y[:num_samples] - dataset.data = x_sampled - dataset.targets = y_sampled - - return dataset - - def _fetch_emnist_dataset(self, train: bool) -> EMNIST: - """Fetch the EMNIST dataset.""" - if self.transform is None: - transform = Compose([Transpose(), ToTensor()]) - - dataset = EMNIST( - root=DATA_DIRNAME, - split="byclass", - train=train, - download=False, - transform=transform, - 
target_transform=self.target_transform, - ) - - if self.sample_to_balance: - dataset = self._sample_to_balance(dataset) - - if self.subsample_fraction is not None: - dataset = self._subsample(dataset) - - return dataset - def _fetch_emnist_data_loaders(self) -> Dict[str, DataLoader]: """Fetches the EMNIST dataset and return a Dict of PyTorch DataLoaders.""" data_loaders = {} @@ -193,10 +315,19 @@ class EmnistDataLoader: else: train = False - dataset = self._fetch_emnist_dataset(train) + emnist_dataset = EmnistDataset( + train=train, + sample_to_balance=self.sample_to_balance, + subsample_fraction=self.subsample_fraction, + transform=self.transform, + target_transform=self.target_transform, + seed=self.seed, + ) + + emnist_dataset.load_emnist_dataset() data_loader = DataLoader( - dataset=dataset, + dataset=emnist_dataset, batch_size=self.batch_size, shuffle=self.shuffle, num_workers=self.num_workers, |