summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets/emnist_dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/datasets/emnist_dataset.py')
-rw-r--r--src/text_recognizer/datasets/emnist_dataset.py32
1 files changed, 23 insertions, 9 deletions
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index 49ebad3..0715aae 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -152,8 +152,7 @@ class EmnistDataset(Dataset):
"""Loads the dataset and the mappings.
Args:
- train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to
- False.
+ train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False.
sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False.
subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None.
transform (Optional[Callable]): Transform(s) for input data. Defaults to None.
@@ -181,17 +180,37 @@ class EmnistDataset(Dataset):
self.seed = seed
self._mapper = EmnistMapper()
- self.input_shape = self._mapper.input_shape
+ self._input_shape = self._mapper.input_shape
self.num_classes = self._mapper.num_classes
# Load dataset.
- self.data, self.targets = self.load_emnist_dataset()
+ self._data, self._targets = self.load_emnist_dataset()
+
+ @property
+ def data(self) -> Tensor:
+ """The input data."""
+ return self._data
+
+ @property
+ def targets(self) -> Tensor:
+ """The target data."""
+ return self._targets
+
+ @property
+ def input_shape(self) -> Tuple:
+ """Input shape of the data."""
+ return self._input_shape
@property
def mapper(self) -> EmnistMapper:
"""Returns the EmnistMapper."""
return self._mapper
+ @property
+ def inverse_mapping(self) -> Dict:
+ """Returns the inverse mapping from character to index."""
+ return self.mapper.inverse_mapping
+
def __len__(self) -> int:
"""Returns the length of the dataset."""
return len(self.data)
@@ -220,11 +239,6 @@ class EmnistDataset(Dataset):
return data, targets
- @property
- def __name__(self) -> str:
- """Returns the name of the dataset."""
- return "EmnistDataset"
-
def __repr__(self) -> str:
"""Returns information about the dataset."""
return (