summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets/dataset.py
blob: 05520e5b561d75c958ad85e698a764804839207c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""Abstract dataset class."""
from typing import Callable, Dict, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.utils import data
from torchvision.transforms import ToTensor

from text_recognizer.datasets.util import EmnistMapper


class Dataset(data.Dataset):
    """Abstract base class with common methods for all datasets.

    Subclasses are expected to populate ``self._data`` and ``self._targets``
    (both start out as None) and to implement ``load_or_generate_data``,
    ``__getitem__`` and ``__repr__``, which all raise NotImplementedError here.
    """

    def __init__(
        self,
        train: bool,
        subsample_fraction: Optional[float] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        """Initialization of Dataset class.

        Args:
            train (bool): If True, the split is set to "train", otherwise to "test".
            subsample_fraction (Optional[float]): The fraction of the dataset to keep when
                `_subsample` is called. Defaults to None, which keeps everything.
            transform (Optional[Callable]): Transform(s) for input data. Defaults to None, in which
                case `ToTensor()` is used.
            target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None,
                in which case `torch.tensor` is used.

        Raises:
            ValueError: If subsample_fraction is not None and outside the open range (0, 1).

        """
        super().__init__()

        self.train = train
        self.split = "train" if self.train else "test"

        # Both endpoints are rejected: 0 would give an empty dataset and 1 is
        # expressed by passing None instead.
        if subsample_fraction is not None and not 0.0 < subsample_fraction < 1.0:
            raise ValueError("The subsample fraction must be in (0, 1).")
        self.subsample_fraction = subsample_fraction

        self._mapper = EmnistMapper()
        self._input_shape = self._mapper.input_shape
        # NOTE(review): the output shape is read from the mapper's private
        # `_num_classes`, while `num_classes` below goes through the public
        # property - confirm that both are meant to hold the same value.
        self._output_shape = self._mapper._num_classes
        self.num_classes = self.mapper.num_classes

        # Fall back to default transforms when none are supplied.
        self.transform = ToTensor() if transform is None else transform
        self.target_transform = (
            torch.tensor if target_transform is None else target_transform
        )

        # Filled in by subclasses once the data has been loaded or generated.
        self._data = None
        self._targets = None

    @property
    def data(self) -> Tensor:
        """The input data, or None if nothing has been loaded yet."""
        return self._data

    @property
    def targets(self) -> Tensor:
        """The target data, or None if nothing has been loaded yet."""
        return self._targets

    @property
    def input_shape(self) -> Tuple:
        """Input shape of the data, as reported by the mapper."""
        return self._input_shape

    @property
    def output_shape(self) -> Tuple:
        """Output shape of the data, as reported by the mapper."""
        return self._output_shape

    @property
    def mapper(self) -> "EmnistMapper":
        """Returns the EmnistMapper."""
        return self._mapper

    @property
    def mapping(self) -> Dict:
        """Return EMNIST mapping from index to character."""
        return self._mapper.mapping

    @property
    def inverse_mapping(self) -> Dict:
        """Returns the inverse mapping from character to index."""
        return self.mapper.inverse_mapping

    def _subsample(self) -> None:
        """Truncate data and targets to the leading `subsample_fraction` of the samples.

        Does nothing when `subsample_fraction` is None. The number of samples
        kept is rounded down, so it can be zero for very small datasets.
        """
        if self.subsample_fraction is None:
            return
        # len() matches what __len__ reports and, unlike `.shape[0]`, also
        # works for plain sequences such as lists.
        num_subsample = int(len(self.data) * self.subsample_fraction)
        self._data = self.data[:num_subsample]
        self._targets = self.targets[:num_subsample]

    def __len__(self) -> int:
        """Returns the length of the dataset."""
        return len(self.data)

    def load_or_generate_data(self) -> None:
        """Load or generate dataset data.

        Raises:
            NotImplementedError: If the method is not implemented in child class.

        """
        raise NotImplementedError

    def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
        """Fetches samples from the dataset.

        Args:
            index (Union[int, torch.Tensor]): The indices of the samples to fetch.

        Raises:
            NotImplementedError: If the method is not implemented in child class.

        """
        raise NotImplementedError

    def __repr__(self) -> str:
        """Returns information about the dataset.

        Raises:
            NotImplementedError: If the method is not implemented in child class.

        """
        raise NotImplementedError