summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets/dataset.py
blob: e7946055d2076dddec2e79709779a846178c7266 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""Abstract dataset class."""
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.utils import data
from torchvision.transforms import ToTensor

import text_recognizer.datasets.transforms as transforms
from text_recognizer.datasets.util import EmnistMapper


class Dataset(data.Dataset):
    """Abstract base class with common methods for all datasets.

    This base class initialises ``_data`` and ``_targets`` to ``None``.
    Subclasses must fill them (presumably in ``load_or_generate_data``) and
    implement ``__getitem__`` and ``__repr__``, which raise
    ``NotImplementedError`` here.
    """

    def __init__(
        self,
        train: bool,
        subsample_fraction: Optional[float] = None,
        transform: Optional[List[Dict]] = None,
        target_transform: Optional[List[Dict]] = None,
        init_token: Optional[str] = None,
        pad_token: Optional[str] = None,
        eos_token: Optional[str] = None,
        lower: bool = False,
    ) -> None:
        """Initialization of Dataset class.

        Args:
            train (bool): If True, the training split is used, otherwise the test split.
            subsample_fraction (Optional[float]): Fraction of the data to keep; must lie in
                the open interval (0, 1). Defaults to None (keep all data).
            transform (Optional[List[Dict]]): List of ``{"type": ..., "args": ...}`` specs
                for transforms applied to the input data. If None, only ``ToTensor`` is
                applied. Defaults to None.
            target_transform (Optional[List[Dict]]): List of ``{"type": ..., "args": ...}``
                specs for transforms applied to the targets, after ``torch.tensor``.
                Defaults to None.
            init_token (Optional[str]): String representing the start of sequence token.
                Defaults to None.
            pad_token (Optional[str]): String representing the pad token. Defaults to None.
            eos_token (Optional[str]): String representing the end of sequence token.
                Defaults to None.
            lower (bool): Only use lower case letters. Defaults to False.

        Raises:
            ValueError: If subsample_fraction is not None and outside the range (0, 1).

        """
        self.train = train
        self.split = "train" if self.train else "test"

        if subsample_fraction is not None and not 0.0 < subsample_fraction < 1.0:
            raise ValueError("The subsample fraction must be in (0, 1).")
        self.subsample_fraction = subsample_fraction

        self._mapper = EmnistMapper(
            init_token=init_token, eos_token=eos_token, pad_token=pad_token, lower=lower
        )
        self._input_shape = self._mapper.input_shape
        # NOTE(review): ``output_shape`` is annotated as Tuple, but this stores the
        # mapper's ``_num_classes`` as-is (presumably an int). Kept unchanged so
        # existing callers of ``output_shape`` keep working.
        self._output_shape = self._mapper._num_classes
        self.num_classes = self.mapper.num_classes

        # Set transforms.
        self.transform = self._configure_transform(transform)
        self.target_transform = self._configure_target_transform(target_transform)

        # Populated by subclasses.
        self._data = None
        self._targets = None

    @staticmethod
    def _build_transforms(transform_specs: Optional[List[Dict]]) -> List[Callable]:
        """Instantiate transforms from ``{"type": ..., "args": ...}`` specs.

        ``type`` names a class in ``text_recognizer.datasets.transforms``; ``args``
        is optional (missing or None means no keyword arguments).

        Args:
            transform_specs (Optional[List[Dict]]): The transform specs, or None.

        Returns:
            List[Callable]: The instantiated transforms, in spec order (empty if None).

        """
        if transform_specs is None:
            return []
        built = []
        for spec in transform_specs:
            t_type = spec["type"]
            # ``args`` may be absent or explicitly None; both mean "no arguments".
            t_args = spec.get("args") or {}
            built.append(getattr(transforms, t_type)(**t_args))
        return built

    def _configure_transform(
        self, transform: Optional[List[Dict]]
    ) -> transforms.Compose:
        """Build the input transform; defaults to ``ToTensor`` when no specs are given."""
        if transform is None:
            return transforms.Compose([ToTensor()])
        return transforms.Compose(self._build_transforms(transform))

    def _configure_target_transform(
        self, target_transform: Optional[List[Dict]]
    ) -> transforms.Compose:
        """Build the target transform; ``torch.tensor`` always runs first."""
        return transforms.Compose(
            [torch.tensor] + self._build_transforms(target_transform)
        )

    @property
    def data(self) -> Tensor:
        """The input data."""
        return self._data

    @property
    def targets(self) -> Tensor:
        """The target data."""
        return self._targets

    @property
    def input_shape(self) -> Tuple:
        """Input shape of the data."""
        return self._input_shape

    @property
    def output_shape(self) -> Tuple:
        """Output shape of the data."""
        return self._output_shape

    @property
    def mapper(self) -> EmnistMapper:
        """Returns the EmnistMapper."""
        return self._mapper

    @property
    def mapping(self) -> Dict:
        """Return EMNIST mapping from index to character."""
        return self._mapper.mapping

    @property
    def inverse_mapping(self) -> Dict:
        """Returns the inverse mapping from character to index."""
        return self.mapper.inverse_mapping

    def _subsample(self) -> None:
        """Keep only the leading ``subsample_fraction`` of data and targets.

        Does nothing when ``subsample_fraction`` is None. Assumes ``_data`` and
        ``_targets`` have already been populated.
        """
        if self.subsample_fraction is None:
            return
        num_subsample = int(len(self.data) * self.subsample_fraction)
        self._data = self.data[:num_subsample]
        self._targets = self.targets[:num_subsample]

    def __len__(self) -> int:
        """Returns the length of the dataset."""
        return len(self.data)

    def load_or_generate_data(self) -> None:
        """Load or generate dataset data.

        Raises:
            NotImplementedError: Must be implemented by a child class.

        """
        raise NotImplementedError

    def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
        """Fetches samples from the dataset.

        Args:
            index (Union[int, torch.Tensor]): The indices of the samples to fetch.

        Raises:
            NotImplementedError: If the method is not implemented in child class.

        """
        raise NotImplementedError

    def __repr__(self) -> str:
        """Returns information about the dataset.

        Raises:
            NotImplementedError: Must be implemented by a child class.

        """
        raise NotImplementedError