diff options
Diffstat (limited to 'src/text_recognizer/datasets/emnist_dataset.py')
-rw-r--r-- | src/text_recognizer/datasets/emnist_dataset.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py index d01dcee..9884fdf 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -22,6 +22,7 @@ class EmnistDataset(Dataset): def __init__( self, + pad_token: str = None, train: bool = False, sample_to_balance: bool = False, subsample_fraction: float = None, @@ -32,6 +33,7 @@ class EmnistDataset(Dataset): """Loads the dataset and the mappings. Args: + pad_token (str): The pad token symbol. Defaults to _. train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False. sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False. subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None. @@ -45,6 +47,7 @@ class EmnistDataset(Dataset): subsample_fraction=subsample_fraction, transform=transform, target_transform=target_transform, + pad_token=pad_token, ) self.sample_to_balance = sample_to_balance @@ -53,8 +56,7 @@ class EmnistDataset(Dataset): if transform is None: self.transform = Compose([Transpose(), ToTensor()]) - # The EMNIST dataset is already casted to tensors. - self.target_transform = target_transform + self.target_transform = None self.seed = seed |