summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets/dataset.py
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-11-08 14:54:44 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2020-11-08 14:54:44 +0100
commitdc28cbe2b4ed77be92ee8b2b69a20689c3bf02a4 (patch)
tree1b5fc0d06952e13727e85c4f973a26d277068453 /src/text_recognizer/datasets/dataset.py
parente181195a699d7fa237f256d90ab4dedffc03d405 (diff)
new updates
Diffstat (limited to 'src/text_recognizer/datasets/dataset.py')
-rw-r--r--src/text_recognizer/datasets/dataset.py22
1 files changed, 13 insertions, 9 deletions
diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py
index 05520e5..2de7f09 100644
--- a/src/text_recognizer/datasets/dataset.py
+++ b/src/text_recognizer/datasets/dataset.py
@@ -18,6 +18,9 @@ class Dataset(data.Dataset):
subsample_fraction: float = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
+ init_token: Optional[str] = None,
+ pad_token: Optional[str] = None,
+ eos_token: Optional[str] = None,
) -> None:
"""Initialization of Dataset class.
@@ -26,12 +29,14 @@ class Dataset(data.Dataset):
subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None.
transform (Optional[Callable]): Transform(s) for input data. Defaults to None.
target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None.
+ init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
+ pad_token (Optional[str]): String representing the pad token. Defaults to None.
+ eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
Raises:
ValueError: If subsample_fraction is not None and outside the range (0, 1).
"""
-
self.train = train
self.split = "train" if self.train else "test"
@@ -40,19 +45,18 @@ class Dataset(data.Dataset):
raise ValueError("The subsample fraction must be in (0, 1).")
self.subsample_fraction = subsample_fraction
- self._mapper = EmnistMapper()
+ self._mapper = EmnistMapper(
+ init_token=init_token, eos_token=eos_token, pad_token=pad_token
+ )
self._input_shape = self._mapper.input_shape
self._output_shape = self._mapper._num_classes
self.num_classes = self.mapper.num_classes
# Set transforms.
- self.transform = transform
- if self.transform is None:
- self.transform = ToTensor()
-
- self.target_transform = target_transform
- if self.target_transform is None:
- self.target_transform = torch.tensor
+ self.transform = transform if transform is not None else ToTensor()
+ self.target_transform = (
+ target_transform if target_transform is not None else torch.tensor
+ )
self._data = None
self._targets = None