diff options
Diffstat (limited to 'src/text_recognizer/datasets/dataset.py')
-rw-r--r-- | src/text_recognizer/datasets/dataset.py | 124 |
1 file changed, 124 insertions, 0 deletions
"""Abstract dataset class."""
from typing import Callable, Dict, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.utils import data
from torchvision.transforms import ToTensor

from text_recognizer.datasets.util import EmnistMapper


class Dataset(data.Dataset):
    """Abstract class with common methods for all datasets."""

    def __init__(
        self,
        train: bool,
        subsample_fraction: Optional[float] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        """Initialization of Dataset class.

        Args:
            train (bool): If True, loads the training set, otherwise the test set is loaded.
            subsample_fraction (Optional[float]): Fraction of the dataset to keep when
                subsampling. If None, the full dataset is used. Defaults to None.
            transform (Optional[Callable]): Transform(s) for input data. Defaults to None.
            target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None.

        Raises:
            ValueError: If subsample_fraction is not None and outside the range (0, 1).

        """
        self.train = train
        self.split = "train" if self.train else "test"

        # Validate arguments before doing any expensive setup so bad input fails fast.
        if subsample_fraction is not None:
            if not 0.0 < subsample_fraction < 1.0:
                raise ValueError("The subsample fraction must be in (0, 1).")
        self.subsample_fraction = subsample_fraction

        self._mapper = EmnistMapper()
        self._input_shape = self._mapper.input_shape
        # Use the public accessor rather than the private ``_num_classes`` attribute
        # (the next line shows ``num_classes`` is the public name).
        # NOTE(review): this stores the number of classes even though the
        # ``output_shape`` property is annotated as returning a Tuple — confirm
        # the intended semantics against EmnistMapper.
        self._output_shape = self._mapper.num_classes
        self.num_classes = self.mapper.num_classes

        # Set transforms, falling back to sensible defaults when none are given.
        self.transform = transform if transform is not None else ToTensor()
        self.target_transform = (
            target_transform if target_transform is not None else torch.tensor
        )

        # Populated by ``load_or_generate_data`` in subclasses.
        self._data = None
        self._targets = None

    @property
    def data(self) -> Tensor:
        """The input data."""
        return self._data

    @property
    def targets(self) -> Tensor:
        """The target data."""
        return self._targets

    @property
    def input_shape(self) -> Tuple:
        """Input shape of the data."""
        return self._input_shape

    @property
    def output_shape(self) -> Tuple:
        """Output shape of the data."""
        return self._output_shape

    @property
    def mapper(self) -> EmnistMapper:
        """Returns the EmnistMapper."""
        return self._mapper

    @property
    def mapping(self) -> Dict:
        """Return EMNIST mapping from index to character."""
        return self._mapper.mapping

    @property
    def inverse_mapping(self) -> Dict:
        """Returns the inverse mapping from character to index."""
        return self.mapper.inverse_mapping

    def _subsample(self) -> None:
        """Keep only the configured fraction of the loaded data.

        No-op when ``subsample_fraction`` is None.
        """
        if self.subsample_fraction is None:
            return
        num_subsample = int(self.data.shape[0] * self.subsample_fraction)
        # Assign the private attributes: ``data`` and ``targets`` are read-only
        # properties, so assigning to them would raise AttributeError.
        self._data = self._data[:num_subsample]
        self._targets = self._targets[:num_subsample]

    def __len__(self) -> int:
        """Returns the length of the dataset."""
        return len(self.data)

    def load_or_generate_data(self) -> None:
        """Load or generate dataset data. Must be implemented by subclasses."""
        raise NotImplementedError

    def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
        """Fetches samples from the dataset.

        Args:
            index (Union[int, torch.Tensor]): The indices of the samples to fetch.

        Raises:
            NotImplementedError: If the method is not implemented in child class.

        """
        raise NotImplementedError

    def __repr__(self) -> str:
        """Returns information about the dataset. Must be implemented by subclasses."""
        raise NotImplementedError