summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets/dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/datasets/dataset.py')
-rw-r--r--src/text_recognizer/datasets/dataset.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py
index 95063bc..e794605 100644
--- a/src/text_recognizer/datasets/dataset.py
+++ b/src/text_recognizer/datasets/dataset.py
@@ -22,6 +22,7 @@ class Dataset(data.Dataset):
init_token: Optional[str] = None,
pad_token: Optional[str] = None,
eos_token: Optional[str] = None,
+ lower: bool = False,
) -> None:
"""Initialization of Dataset class.
@@ -33,6 +34,7 @@ class Dataset(data.Dataset):
init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
pad_token (Optional[str]): String representing the pad token. Defaults to None.
eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
+ lower (bool): Only use lower case letters. Defaults to False.
Raises:
ValueError: If subsample_fraction is not None and outside the range (0, 1).
@@ -47,7 +49,7 @@ class Dataset(data.Dataset):
self.subsample_fraction = subsample_fraction
self._mapper = EmnistMapper(
- init_token=init_token, eos_token=eos_token, pad_token=pad_token
+ init_token=init_token, eos_token=eos_token, pad_token=pad_token, lower=lower
)
self._input_shape = self._mapper.input_shape
self._output_shape = self._mapper._num_classes