"""Abstract dataset class."""
from typing import Callable, Dict, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.utils import data
from torchvision.transforms import ToTensor
from text_recognizer.datasets.util import EmnistMapper
class Dataset(data.Dataset):
"""Abstract class for with common methods for all datasets."""
def __init__(
        self,
        train: bool,
        subsample_fraction: Optional[float] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
"""Initialization of Dataset class.
Args:
train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False.
subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None.
transform (Optional[Callable]): Transform(s) for input data. Defaults to None.
target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None.
Raises:
ValueError: If subsample_fraction is not None and outside the range (0, 1).
"""
        self.train = train
        self.split = "train" if self.train else "test"

        if subsample_fraction is not None:
            if not 0.0 < subsample_fraction < 1.0:
                raise ValueError("The subsample fraction must be in (0, 1).")
        self.subsample_fraction = subsample_fraction

        self._mapper = EmnistMapper()
        self._input_shape = self._mapper.input_shape
        self._output_shape = self._mapper.num_classes
        self.num_classes = self.mapper.num_classes

        # Set transforms.
        self.transform = transform
        if self.transform is None:
            self.transform = ToTensor()

        self.target_transform = target_transform
        if self.target_transform is None:
            self.target_transform = torch.tensor

        self._data = None
        self._targets = None

    @property
    def data(self) -> Tensor:
        """The input data."""
        return self._data

    @property
    def targets(self) -> Tensor:
        """The target data."""
        return self._targets

    @property
    def input_shape(self) -> Tuple:
        """Input shape of the data."""
        return self._input_shape

    @property
    def output_shape(self) -> Tuple:
        """Output shape of the data."""
        return self._output_shape

    @property
    def mapper(self) -> EmnistMapper:
        """Returns the EmnistMapper."""
        return self._mapper

    @property
    def mapping(self) -> Dict:
        """Returns the EMNIST mapping from index to character."""
        return self._mapper.mapping

    @property
    def inverse_mapping(self) -> Dict:
        """Returns the inverse mapping from character to index."""
        return self.mapper.inverse_mapping

    def _subsample(self) -> None:
        """Subsamples the dataset to the specified fraction."""
        if self.subsample_fraction is None:
            return
        num_subsample = int(self.data.shape[0] * self.subsample_fraction)
        self._data = self.data[:num_subsample]
        self._targets = self.targets[:num_subsample]

    def __len__(self) -> int:
        """Returns the length of the dataset."""
        return len(self.data)

    def load_or_generate_data(self) -> None:
        """Load or generate dataset data."""
        raise NotImplementedError

    def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
        """Fetches samples from the dataset.

        Args:
            index (Union[int, torch.Tensor]): The indices of the samples to fetch.

        Raises:
            NotImplementedError: If the method is not implemented in the child class.

        """
        raise NotImplementedError

    def __repr__(self) -> str:
        """Returns information about the dataset."""
        raise NotImplementedError
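

# Illustrative usage sketch (not part of the original module): a minimal
# concrete subclass showing how load_or_generate_data() and __getitem__() are
# intended to be implemented. The class name ToyDataset and the random data
# are hypothetical stand-ins; instantiating it still requires the EmnistMapper
# assets used by the base class.
class ToyDataset(Dataset):
    """Toy dataset that fills data/targets with random 28x28 images."""

    def load_or_generate_data(self) -> None:
        """Generates random grayscale images and cycling integer labels."""
        import numpy as np  # Local import to keep the sketch self-contained.

        num_samples = 100 if self.train else 20
        # Store uint8 HxW arrays so the default ToTensor transform can convert them.
        self._data = np.random.randint(0, 256, (num_samples, 28, 28), dtype=np.uint8)
        self._targets = np.arange(num_samples) % self.num_classes
        self._subsample()

    def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
        """Returns a (data, target) pair with the transforms applied."""
        if torch.is_tensor(index):
            index = index.item()  # The sketch handles one sample at a time.
        data = self.data[index]
        targets = self.targets[index]
        if self.transform is not None:
            data = self.transform(data)
        if self.target_transform is not None:
            targets = self.target_transform(targets)
        return data, targets

    def __repr__(self) -> str:
        """Returns information about the dataset."""
        return f"ToyDataset(train={self.train}, num_samples={len(self)})"


# A ToyDataset instance can be wrapped in torch.utils.data.DataLoader like any
# other torch Dataset once load_or_generate_data() has been called.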