From 3b06ef615a8db67a03927576e0c12fbfb2501f5f Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Mon, 14 Sep 2020 22:15:47 +0200 Subject: Fixed CTC loss. --- src/text_recognizer/datasets/dataset.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'src/text_recognizer/datasets/dataset.py') diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py index f328a0f..05520e5 100644 --- a/src/text_recognizer/datasets/dataset.py +++ b/src/text_recognizer/datasets/dataset.py @@ -23,7 +23,7 @@ class Dataset(data.Dataset): Args: train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False. - subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None. + subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None. transform (Optional[Callable]): Transform(s) for input data. Defaults to None. target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None. @@ -31,6 +31,7 @@ class Dataset(data.Dataset): ValueError: If subsample_fraction is not None and outside the range (0, 1). """ + self.train = train self.split = "train" if self.train else "test" @@ -96,8 +97,8 @@ class Dataset(data.Dataset): if self.subsample_fraction is None: return num_subsample = int(self.data.shape[0] * self.subsample_fraction) - self.data = self.data[:num_subsample] - self.targets = self.targets[:num_subsample] + self._data = self.data[:num_subsample] + self._targets = self.targets[:num_subsample] def __len__(self) -> int: """Returns the length of the dataset.""" -- cgit v1.2.3-70-g09d2