Diffstat (limited to 'src/text_recognizer/datasets/dataset.py')
-rw-r--r--  src/text_recognizer/datasets/dataset.py  124
1 file changed, 124 insertions, 0 deletions
diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py
new file mode 100644
index 0000000..f328a0f
--- /dev/null
+++ b/src/text_recognizer/datasets/dataset.py
@@ -0,0 +1,124 @@
+"""Abstract dataset class."""
+from typing import Callable, Dict, Optional, Tuple, Union
+
+import torch
+from torch import Tensor
+from torch.utils import data
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets.util import EmnistMapper
+
+
+class Dataset(data.Dataset):
+ """Abstract class for with common methods for all datasets."""
+
+ def __init__(
+ self,
+ train: bool,
+ subsample_fraction: float = None,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ ) -> None:
+ """Initialization of Dataset class.
+
+ Args:
+ train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False.
+ subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None.
+ transform (Optional[Callable]): Transform(s) for input data. Defaults to None.
+ target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None.
+
+ Raises:
+ ValueError: If subsample_fraction is not None and outside the range (0, 1).
+
+ """
+        self.train = train
+        self.split = "train" if self.train else "test"
+
+        if subsample_fraction is not None:
+            if not 0.0 < subsample_fraction < 1.0:
+                raise ValueError("The subsample fraction must be in (0, 1).")
+        self.subsample_fraction = subsample_fraction
+
+        self._mapper = EmnistMapper()
+        self._input_shape = self._mapper.input_shape
+        self._output_shape = self._mapper.num_classes
+        self.num_classes = self.mapper.num_classes
+
+        # Set transforms.
+        self.transform = transform
+        if self.transform is None:
+            self.transform = ToTensor()
+
+        self.target_transform = target_transform
+        if self.target_transform is None:
+            self.target_transform = torch.tensor
+
+        self._data = None
+        self._targets = None
+
+    @property
+    def data(self) -> Tensor:
+        """The input data."""
+        return self._data
+
+    @property
+    def targets(self) -> Tensor:
+        """The target data."""
+        return self._targets
+
+    @property
+    def input_shape(self) -> Tuple:
+        """Input shape of the data."""
+        return self._input_shape
+
+    @property
+    def output_shape(self) -> Tuple:
+        """Output shape of the data."""
+        return self._output_shape
+
+    @property
+    def mapper(self) -> EmnistMapper:
+        """Returns the EmnistMapper."""
+        return self._mapper
+
+    @property
+    def mapping(self) -> Dict:
+        """Returns the EMNIST mapping from index to character."""
+        return self._mapper.mapping
+
+    @property
+    def inverse_mapping(self) -> Dict:
+        """Returns the inverse mapping from character to index."""
+        return self.mapper.inverse_mapping
+
+    def _subsample(self) -> None:
+        """Subsample the dataset to the specified fraction, if one is set."""
+        if self.subsample_fraction is None:
+            return
+        num_subsample = int(self.data.shape[0] * self.subsample_fraction)
+        self._data = self._data[:num_subsample]
+        self._targets = self._targets[:num_subsample]
+
+    def __len__(self) -> int:
+        """Returns the length of the dataset."""
+        return len(self.data)
+
+    def load_or_generate_data(self) -> None:
+        """Load or generate dataset data."""
+        raise NotImplementedError
+
+    def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
+        """Fetches samples from the dataset.
+
+        Args:
+            index (Union[int, torch.Tensor]): The indices of the samples to fetch.
+
+        Raises:
+            NotImplementedError: If the method is not implemented in the child class.
+
+        """
+        raise NotImplementedError
+
+    def __repr__(self) -> str:
+        """Returns information about the dataset."""
+        raise NotImplementedError
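
Below is a minimal usage sketch, not part of the commit above, showing how a concrete subclass might fill in the abstract hooks. The ToyDataset class and its random in-memory data are illustrative assumptions only; a real subclass would load actual images in load_or_generate_data, and the example assumes the EMNIST mapping required by EmnistMapper is available on disk.

from typing import Tuple

import numpy as np
import torch
from torch import Tensor

from text_recognizer.datasets.dataset import Dataset


class ToyDataset(Dataset):
    """Hypothetical in-memory dataset used only to illustrate the base class contract."""

    def load_or_generate_data(self) -> None:
        # Populate the private buffers that the read-only properties expose.
        self._data = np.random.rand(100, 28, 28).astype("float32")
        self._targets = np.random.randint(0, self.num_classes, size=(100,))
        self._subsample()

    def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
        data = self.transform(self.data[index])              # defaults to ToTensor()
        target = self.target_transform(self.targets[index])  # defaults to torch.tensor
        return data, target

    def __repr__(self) -> str:
        return f"ToyDataset(split={self.split}, samples={len(self)})"


dataset = ToyDataset(train=True, subsample_fraction=0.5)
dataset.load_or_generate_data()
image, label = dataset[0]  # image: tensor of shape (1, 28, 28), label: scalar tensor
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)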