From 4d7713746eb936832e84852e90292936b933e87d Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Thu, 22 Oct 2020 22:45:58 +0200 Subject: Transfomer added, many other changes. --- src/text_recognizer/datasets/dataset.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) (limited to 'src/text_recognizer/datasets/dataset.py') diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py index 05520e5..2de7f09 100644 --- a/src/text_recognizer/datasets/dataset.py +++ b/src/text_recognizer/datasets/dataset.py @@ -18,6 +18,9 @@ class Dataset(data.Dataset): subsample_fraction: float = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, + init_token: Optional[str] = None, + pad_token: Optional[str] = None, + eos_token: Optional[str] = None, ) -> None: """Initialization of Dataset class. @@ -26,12 +29,14 @@ class Dataset(data.Dataset): subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None. transform (Optional[Callable]): Transform(s) for input data. Defaults to None. target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None. + init_token (Optional[str]): String representing the start of sequence token. Defaults to None. + pad_token (Optional[str]): String representing the pad token. Defaults to None. + eos_token (Optional[str]): String representing the end of sequence token. Defaults to None. Raises: ValueError: If subsample_fraction is not None and outside the range (0, 1). """ - self.train = train self.split = "train" if self.train else "test" @@ -40,19 +45,18 @@ class Dataset(data.Dataset): raise ValueError("The subsample fraction must be in (0, 1).") self.subsample_fraction = subsample_fraction - self._mapper = EmnistMapper() + self._mapper = EmnistMapper( + init_token=init_token, eos_token=eos_token, pad_token=pad_token + ) self._input_shape = self._mapper.input_shape self._output_shape = self._mapper._num_classes self.num_classes = self.mapper.num_classes # Set transforms. - self.transform = transform - if self.transform is None: - self.transform = ToTensor() - - self.target_transform = target_transform - if self.target_transform is None: - self.target_transform = torch.tensor + self.transform = transform if transform is not None else ToTensor() + self.target_transform = ( + target_transform if target_transform is not None else torch.tensor + ) self._data = None self._targets = None -- cgit v1.2.3-70-g09d2