diff options
Diffstat (limited to 'src/text_recognizer/datasets/dataset.py')
-rw-r--r-- | src/text_recognizer/datasets/dataset.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py index f328a0f..05520e5 100644 --- a/src/text_recognizer/datasets/dataset.py +++ b/src/text_recognizer/datasets/dataset.py @@ -23,7 +23,7 @@ class Dataset(data.Dataset): Args: train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False. - subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None. + subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None. transform (Optional[Callable]): Transform(s) for input data. Defaults to None. target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None. @@ -31,6 +31,7 @@ class Dataset(data.Dataset): ValueError: If subsample_fraction is not None and outside the range (0, 1). """ + self.train = train self.split = "train" if self.train else "test" @@ -96,8 +97,8 @@ class Dataset(data.Dataset): if self.subsample_fraction is None: return num_subsample = int(self.data.shape[0] * self.subsample_fraction) - self.data = self.data[:num_subsample] - self.targets = self.targets[:num_subsample] + self._data = self.data[:num_subsample] + self._targets = self.targets[:num_subsample] def __len__(self) -> int: """Returns the length of the dataset.""" |