diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-20 18:09:06 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-20 18:09:06 +0100 |
commit | 7e8e54e84c63171e748bbf09516fd517e6821ace (patch) | |
tree | 996093f75a5d488dddf7ea1f159ed343a561ef89 /src/text_recognizer/datasets/dataset.py | |
parent | b0719d84138b6bbe5f04a4982dfca673aea1a368 (diff) |
Inital commit for refactoring to lightning
Diffstat (limited to 'src/text_recognizer/datasets/dataset.py')
-rw-r--r-- | src/text_recognizer/datasets/dataset.py | 152 |
1 files changed, 0 insertions, 152 deletions
diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py deleted file mode 100644 index e794605..0000000 --- a/src/text_recognizer/datasets/dataset.py +++ /dev/null @@ -1,152 +0,0 @@ -"""Abstract dataset class.""" -from typing import Callable, Dict, List, Optional, Tuple, Union - -import torch -from torch import Tensor -from torch.utils import data -from torchvision.transforms import ToTensor - -import text_recognizer.datasets.transforms as transforms -from text_recognizer.datasets.util import EmnistMapper - - -class Dataset(data.Dataset): - """Abstract class for with common methods for all datasets.""" - - def __init__( - self, - train: bool, - subsample_fraction: float = None, - transform: Optional[List[Dict]] = None, - target_transform: Optional[List[Dict]] = None, - init_token: Optional[str] = None, - pad_token: Optional[str] = None, - eos_token: Optional[str] = None, - lower: bool = False, - ) -> None: - """Initialization of Dataset class. - - Args: - train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False. - subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None. - transform (Optional[List[Dict]]): List of Transform types and args for input data. Defaults to None. - target_transform (Optional[List[Dict]]): List of Transform types and args for output data. Defaults to None. - init_token (Optional[str]): String representing the start of sequence token. Defaults to None. - pad_token (Optional[str]): String representing the pad token. Defaults to None. - eos_token (Optional[str]): String representing the end of sequence token. Defaults to None. - lower (bool): Only use lower case letters. Defaults to False. - - Raises: - ValueError: If subsample_fraction is not None and outside the range (0, 1). - - """ - self.train = train - self.split = "train" if self.train else "test" - - if subsample_fraction is not None: - if not 0.0 < subsample_fraction < 1.0: - raise ValueError("The subsample fraction must be in (0, 1).") - self.subsample_fraction = subsample_fraction - - self._mapper = EmnistMapper( - init_token=init_token, eos_token=eos_token, pad_token=pad_token, lower=lower - ) - self._input_shape = self._mapper.input_shape - self._output_shape = self._mapper._num_classes - self.num_classes = self.mapper.num_classes - - # Set transforms. - self.transform = self._configure_transform(transform) - self.target_transform = self._configure_target_transform(target_transform) - - self._data = None - self._targets = None - - def _configure_transform(self, transform: List[Dict]) -> transforms.Compose: - transform_list = [] - if transform is not None: - for t in transform: - t_type = t["type"] - t_args = t["args"] or {} - transform_list.append(getattr(transforms, t_type)(**t_args)) - else: - transform_list.append(ToTensor()) - return transforms.Compose(transform_list) - - def _configure_target_transform( - self, target_transform: List[Dict] - ) -> transforms.Compose: - target_transform_list = [torch.tensor] - if target_transform is not None: - for t in target_transform: - t_type = t["type"] - t_args = t["args"] or {} - target_transform_list.append(getattr(transforms, t_type)(**t_args)) - return transforms.Compose(target_transform_list) - - @property - def data(self) -> Tensor: - """The input data.""" - return self._data - - @property - def targets(self) -> Tensor: - """The target data.""" - return self._targets - - @property - def input_shape(self) -> Tuple: - """Input shape of the data.""" - return self._input_shape - - @property - def output_shape(self) -> Tuple: - """Output shape of the data.""" - return self._output_shape - - @property - def mapper(self) -> EmnistMapper: - """Returns the EmnistMapper.""" - return self._mapper - - @property - def mapping(self) -> Dict: - """Return EMNIST mapping from index to character.""" - return self._mapper.mapping - - @property - def inverse_mapping(self) -> Dict: - """Returns the inverse mapping from character to index.""" - return self.mapper.inverse_mapping - - def _subsample(self) -> None: - """Only this fraction of the data will be loaded.""" - if self.subsample_fraction is None: - return - num_subsample = int(self.data.shape[0] * self.subsample_fraction) - self._data = self.data[:num_subsample] - self._targets = self.targets[:num_subsample] - - def __len__(self) -> int: - """Returns the length of the dataset.""" - return len(self.data) - - def load_or_generate_data(self) -> None: - """Load or generate dataset data.""" - raise NotImplementedError - - def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: - """Fetches samples from the dataset. - - Args: - index (Union[int, torch.Tensor]): The indices of the samples to fetch. - - Raises: - NotImplementedError: If the method is not implemented in child class. - - """ - raise NotImplementedError - - def __repr__(self) -> str: - """Returns information about the dataset.""" - raise NotImplementedError |