summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets/emnist_dataset.py
blob: 81268fbea7647d63a6fdf03ce6fab24ecf358f62 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""Emnist dataset: black and white images of handwritten characters (Aa-Zz) and digits (0-9)."""

import json
from pathlib import Path
from typing import Callable, Optional, Tuple, Union

from loguru import logger
import numpy as np
from PIL import Image
import torch
from torch import Tensor
from torchvision.datasets import EMNIST
from torchvision.transforms import Compose, ToTensor

from text_recognizer.datasets.dataset import Dataset
from text_recognizer.datasets.util import DATA_DIRNAME


class Transpose:
    """Transposes the EMNIST image to the correct orientation.

    The dataset stores characters transposed, so the first two axes (rows and
    columns for a single 2-D image) are swapped.
    """

    def __call__(self, image: Union[Image.Image, np.ndarray, Tensor]) -> np.ndarray:
        """Swaps the first two axes of the image.

        Args:
            image (Union[Image.Image, np.ndarray, Tensor]): Anything `np.array` can
                convert. Within this module it is a single sample indexed from the
                EMNIST data.

        Returns:
            np.ndarray: The input as a NumPy array with axes 0 and 1 swapped.

        """
        return np.array(image).swapaxes(0, 1)


class EmnistDataset(Dataset):
    """This is a class for resampling and subsampling the PyTorch EMNIST dataset."""

    def __init__(
        self,
        train: bool = False,
        sample_to_balance: bool = False,
        subsample_fraction: Optional[float] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        seed: int = 4711,
    ) -> None:
        """Loads the dataset and the mappings.

        Args:
            train (bool): If True, loads the EMNIST training split, otherwise the EMNIST
                test split is loaded. Defaults to False.
            sample_to_balance (bool): Resamples the dataset to make it balanced when data is
                loaded. Defaults to False.
            subsample_fraction (Optional[float]): Forwarded to the base class; when not None,
                `_subsample()` is called after loading (presumably keeping this fraction of the
                samples). Defaults to None.
            transform (Optional[Callable]): Transform(s) for input data. If None, the images are
                transposed and converted to tensors scaled to [0, 1]. Defaults to None.
            target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None.
            seed (int): Seed for the balancing resampling. Defaults to 4711.

        """
        super().__init__(
            train=train,
            subsample_fraction=subsample_fraction,
            transform=transform,
            target_transform=target_transform,
        )

        self.sample_to_balance = sample_to_balance

        # Have to transpose the emnist characters, ToTensor norms input between [0,1].
        if transform is None:
            self.transform = Compose([Transpose(), ToTensor()])

        # The EMNIST dataset is already casted to tensors.
        self.target_transform = target_transform

        self.seed = seed

    def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
        """Fetches samples from the dataset.

        Args:
            index (Union[int, Tensor]): The indices of the samples to fetch.

        Returns:
            Tuple[Tensor, Tensor]: Data target tuple.

        """
        # Tensor indices are converted to plain Python ints/lists before indexing.
        if torch.is_tensor(index):
            index = index.tolist()

        data = self.data[index]
        targets = self.targets[index]

        if self.transform:
            data = self.transform(data)

        if self.target_transform:
            targets = self.target_transform(targets)

        return data, targets

    def __repr__(self) -> str:
        """Returns information about the dataset."""
        return (
            "EMNIST Dataset\n"
            f"Num classes: {self.num_classes}\n"
            f"Input shape: {self.input_shape}\n"
            f"Mapping: {self.mapper.mapping}\n"
        )

    def _sample_to_balance(self) -> None:
        """Because the dataset is not balanced, we take at most the mean number of instances per class.

        For every label, `num_to_sample` indices are drawn with replacement from that
        label's samples and then de-duplicated, so each class keeps at most that many
        unique samples. A local RNG seeded with `self.seed` is used instead of
        `np.random.seed`, so the global NumPy RNG state is left untouched. The draws
        are identical to those from the seeded global legacy RNG.
        """
        rng = np.random.RandomState(self.seed)
        x = self._data
        y = self._targets
        # The mean is taken over bincount, i.e. over every label from 0 to max(y).
        num_to_sample = int(np.bincount(y.flatten()).mean())
        all_sampled_indices = []
        for label in np.unique(y.flatten()):
            inds = np.where(y == label)[0]
            # Sampling is with replacement; np.unique drops duplicate draws.
            sampled_indices = np.unique(rng.choice(inds, num_to_sample))
            all_sampled_indices.append(sampled_indices)
        indices = np.concatenate(all_sampled_indices)
        self._data = x[indices]
        self._targets = y[indices]

    def load_or_generate_data(self) -> None:
        """Fetch the EMNIST dataset.

        Loads the "byclass" split from DATA_DIRNAME without downloading it, so the files
        must already exist there. Then optionally balances the data
        (`sample_to_balance`) and subsamples it (`subsample_fraction`).
        """
        dataset = EMNIST(
            root=DATA_DIRNAME,
            split="byclass",
            train=self.train,
            download=False,
            transform=None,
            target_transform=None,
        )

        self._data = dataset.data
        self._targets = dataset.targets

        if self.sample_to_balance:
            self._sample_to_balance()

        if self.subsample_fraction is not None:
            self._subsample()