summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets/dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/datasets/dataset.py')
-rw-r--r--src/text_recognizer/datasets/dataset.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py
index f328a0f..05520e5 100644
--- a/src/text_recognizer/datasets/dataset.py
+++ b/src/text_recognizer/datasets/dataset.py
@@ -23,7 +23,7 @@ class Dataset(data.Dataset):
Args:
train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False.
- subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None.
+ subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None.
transform (Optional[Callable]): Transform(s) for input data. Defaults to None.
target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None.
@@ -31,6 +31,7 @@ class Dataset(data.Dataset):
ValueError: If subsample_fraction is not None and outside the range (0, 1).
"""
+
self.train = train
self.split = "train" if self.train else "test"
@@ -96,8 +97,8 @@ class Dataset(data.Dataset):
if self.subsample_fraction is None:
return
num_subsample = int(self.data.shape[0] * self.subsample_fraction)
- self.data = self.data[:num_subsample]
- self.targets = self.targets[:num_subsample]
+ self._data = self.data[:num_subsample]
+ self._targets = self.targets[:num_subsample]
def __len__(self) -> int:
"""Returns the length of the dataset."""