Diffstat (limited to 'text_recognizer/datasets/emnist_dataset.py')
-rw-r--r-- | text_recognizer/datasets/emnist_dataset.py | 131 |
1 file changed, 131 insertions, 0 deletions
diff --git a/text_recognizer/datasets/emnist_dataset.py b/text_recognizer/datasets/emnist_dataset.py
new file mode 100644
index 0000000..9884fdf
--- /dev/null
+++ b/text_recognizer/datasets/emnist_dataset.py
@@ -0,0 +1,131 @@
+"""EMNIST dataset: black and white images of handwritten characters (A-Z, a-z) and digits (0-9)."""
+
+import json
+from pathlib import Path
+from typing import Callable, Optional, Tuple, Union
+
+from loguru import logger
+import numpy as np
+from PIL import Image
+import torch
+from torch import Tensor
+from torchvision.datasets import EMNIST
+from torchvision.transforms import Compose, ToTensor
+
+from text_recognizer.datasets.dataset import Dataset
+from text_recognizer.datasets.transforms import Transpose
+from text_recognizer.datasets.util import DATA_DIRNAME
+
+
+class EmnistDataset(Dataset):
+    """A class for resampling and subsampling the PyTorch EMNIST dataset."""
+
+    def __init__(
+        self,
+        pad_token: str = None,
+        train: bool = False,
+        sample_to_balance: bool = False,
+        subsample_fraction: float = None,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        seed: int = 4711,
+    ) -> None:
+        """Loads the dataset and the mappings.
+
+        Args:
+            pad_token (str): The pad token symbol. Defaults to _.
+            train (bool): If True, loads the training set; otherwise the validation set is loaded. Defaults to False.
+            sample_to_balance (bool): If True, resamples the dataset so that the classes are balanced. Defaults to False.
+            subsample_fraction (float): The fraction of the dataset to keep when subsampling. Defaults to None.
+            transform (Optional[Callable]): Transform(s) for input data. Defaults to None.
+            target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None.
+            seed (int): Seed number. Defaults to 4711.
+
+        """
+        super().__init__(
+            train=train,
+            subsample_fraction=subsample_fraction,
+            transform=transform,
+            target_transform=target_transform,
+            pad_token=pad_token,
+        )
+
+        self.sample_to_balance = sample_to_balance
+
+        # The EMNIST characters have to be transposed; ToTensor normalizes the input to [0, 1].
+        if transform is None:
+            self.transform = Compose([Transpose(), ToTensor()])
+
+        self.target_transform = None
+
+        self.seed = seed
+
+    def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
+        """Fetches samples from the dataset.
+
+        Args:
+            index (Union[int, Tensor]): The indices of the samples to fetch.
+
+        Returns:
+            Tuple[Tensor, Tensor]: Data target tuple.
+
+        """
+        if torch.is_tensor(index):
+            index = index.tolist()
+
+        data = self.data[index]
+        targets = self.targets[index]
+
+        if self.transform:
+            data = self.transform(data)
+
+        if self.target_transform:
+            targets = self.target_transform(targets)
+
+        return data, targets
+
+    def __repr__(self) -> str:
+        """Returns information about the dataset."""
+        return (
+            "EMNIST Dataset\n"
+            f"Num classes: {self.num_classes}\n"
+            f"Input shape: {self.input_shape}\n"
+            f"Mapping: {self.mapper.mapping}\n"
+        )
+
+    def _sample_to_balance(self) -> None:
+        """Because the dataset is not balanced, we take at most the mean number of instances per class."""
+        np.random.seed(self.seed)
+        x = self._data
+        y = self._targets
+        num_to_sample = int(np.bincount(y.flatten()).mean())
+        all_sampled_indices = []
+        for label in np.unique(y.flatten()):
+            inds = np.where(y == label)[0]
+            sampled_indices = np.unique(np.random.choice(inds, num_to_sample))
+            all_sampled_indices.append(sampled_indices)
+        indices = np.concatenate(all_sampled_indices)
+        x_sampled = x[indices]
+        y_sampled = y[indices]
+        self._data = x_sampled
+        self._targets = y_sampled
+
+    def load_or_generate_data(self) -> None:
+        """Fetch the EMNIST dataset."""
+        dataset = EMNIST(
+            root=DATA_DIRNAME,
+            split="byclass",
+            train=self.train,
+            download=False,
+            transform=None,
+            target_transform=None,
+        )
+
+        self._data = dataset.data
+        self._targets = dataset.targets
+
+        if self.sample_to_balance:
+            self._sample_to_balance()
+
+        if self.subsample_fraction is not None:
+            self._subsample()