diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-14 22:15:47 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-14 22:15:47 +0200 |
commit | 3b06ef615a8db67a03927576e0c12fbfb2501f5f (patch) | |
tree | e1c2b1289971c8480327408de46152481e99b539 /src/text_recognizer/datasets/dataset.py | |
parent | 2b63fd952bdc9c7c72edd501cbcdbf3231e98f00 (diff) |
Fixed CTC loss.
Diffstat (limited to 'src/text_recognizer/datasets/dataset.py')
-rw-r--r-- | src/text_recognizer/datasets/dataset.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py index f328a0f..05520e5 100644 --- a/src/text_recognizer/datasets/dataset.py +++ b/src/text_recognizer/datasets/dataset.py @@ -23,7 +23,7 @@ class Dataset(data.Dataset): Args: train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False. - subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None. + subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None. transform (Optional[Callable]): Transform(s) for input data. Defaults to None. target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None. @@ -31,6 +31,7 @@ class Dataset(data.Dataset): ValueError: If subsample_fraction is not None and outside the range (0, 1). """ + self.train = train self.split = "train" if self.train else "test" @@ -96,8 +97,8 @@ class Dataset(data.Dataset): if self.subsample_fraction is None: return num_subsample = int(self.data.shape[0] * self.subsample_fraction) - self.data = self.data[:num_subsample] - self.targets = self.targets[:num_subsample] + self._data = self.data[:num_subsample] + self._targets = self.targets[:num_subsample] def __len__(self) -> int: """Returns the length of the dataset.""" |