diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-20 18:09:06 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-20 18:09:06 +0100 |
commit | 7e8e54e84c63171e748bbf09516fd517e6821ace (patch) | |
tree | 996093f75a5d488dddf7ea1f159ed343a561ef89 /text_recognizer/datasets/dataset.py | |
parent | b0719d84138b6bbe5f04a4982dfca673aea1a368 (diff) |
Inital commit for refactoring to lightning
Diffstat (limited to 'text_recognizer/datasets/dataset.py')
-rw-r--r-- | text_recognizer/datasets/dataset.py | 152 |
1 files changed, 152 insertions, 0 deletions
diff --git a/text_recognizer/datasets/dataset.py b/text_recognizer/datasets/dataset.py new file mode 100644 index 0000000..e794605 --- /dev/null +++ b/text_recognizer/datasets/dataset.py @@ -0,0 +1,152 @@ +"""Abstract dataset class.""" +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +from torch import Tensor +from torch.utils import data +from torchvision.transforms import ToTensor + +import text_recognizer.datasets.transforms as transforms +from text_recognizer.datasets.util import EmnistMapper + + +class Dataset(data.Dataset): + """Abstract class for with common methods for all datasets.""" + + def __init__( + self, + train: bool, + subsample_fraction: float = None, + transform: Optional[List[Dict]] = None, + target_transform: Optional[List[Dict]] = None, + init_token: Optional[str] = None, + pad_token: Optional[str] = None, + eos_token: Optional[str] = None, + lower: bool = False, + ) -> None: + """Initialization of Dataset class. + + Args: + train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False. + subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None. + transform (Optional[List[Dict]]): List of Transform types and args for input data. Defaults to None. + target_transform (Optional[List[Dict]]): List of Transform types and args for output data. Defaults to None. + init_token (Optional[str]): String representing the start of sequence token. Defaults to None. + pad_token (Optional[str]): String representing the pad token. Defaults to None. + eos_token (Optional[str]): String representing the end of sequence token. Defaults to None. + lower (bool): Only use lower case letters. Defaults to False. + + Raises: + ValueError: If subsample_fraction is not None and outside the range (0, 1). + + """ + self.train = train + self.split = "train" if self.train else "test" + + if subsample_fraction is not None: + if not 0.0 < subsample_fraction < 1.0: + raise ValueError("The subsample fraction must be in (0, 1).") + self.subsample_fraction = subsample_fraction + + self._mapper = EmnistMapper( + init_token=init_token, eos_token=eos_token, pad_token=pad_token, lower=lower + ) + self._input_shape = self._mapper.input_shape + self._output_shape = self._mapper._num_classes + self.num_classes = self.mapper.num_classes + + # Set transforms. + self.transform = self._configure_transform(transform) + self.target_transform = self._configure_target_transform(target_transform) + + self._data = None + self._targets = None + + def _configure_transform(self, transform: List[Dict]) -> transforms.Compose: + transform_list = [] + if transform is not None: + for t in transform: + t_type = t["type"] + t_args = t["args"] or {} + transform_list.append(getattr(transforms, t_type)(**t_args)) + else: + transform_list.append(ToTensor()) + return transforms.Compose(transform_list) + + def _configure_target_transform( + self, target_transform: List[Dict] + ) -> transforms.Compose: + target_transform_list = [torch.tensor] + if target_transform is not None: + for t in target_transform: + t_type = t["type"] + t_args = t["args"] or {} + target_transform_list.append(getattr(transforms, t_type)(**t_args)) + return transforms.Compose(target_transform_list) + + @property + def data(self) -> Tensor: + """The input data.""" + return self._data + + @property + def targets(self) -> Tensor: + """The target data.""" + return self._targets + + @property + def input_shape(self) -> Tuple: + """Input shape of the data.""" + return self._input_shape + + @property + def output_shape(self) -> Tuple: + """Output shape of the data.""" + return self._output_shape + + @property + def mapper(self) -> EmnistMapper: + """Returns the EmnistMapper.""" + return self._mapper + + @property + def mapping(self) -> Dict: + """Return EMNIST mapping from index to character.""" + return self._mapper.mapping + + @property + def inverse_mapping(self) -> Dict: + """Returns the inverse mapping from character to index.""" + return self.mapper.inverse_mapping + + def _subsample(self) -> None: + """Only this fraction of the data will be loaded.""" + if self.subsample_fraction is None: + return + num_subsample = int(self.data.shape[0] * self.subsample_fraction) + self._data = self.data[:num_subsample] + self._targets = self.targets[:num_subsample] + + def __len__(self) -> int: + """Returns the length of the dataset.""" + return len(self.data) + + def load_or_generate_data(self) -> None: + """Load or generate dataset data.""" + raise NotImplementedError + + def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: + """Fetches samples from the dataset. + + Args: + index (Union[int, torch.Tensor]): The indices of the samples to fetch. + + Raises: + NotImplementedError: If the method is not implemented in child class. + + """ + raise NotImplementedError + + def __repr__(self) -> str: + """Returns information about the dataset.""" + raise NotImplementedError |