From beeaef529e7c893a3475fe27edc880e283373725 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Sun, 8 Nov 2020 12:41:04 +0100 Subject: Trying to get the CNNTransformer to work, but it is hard. --- src/text_recognizer/datasets/emnist_dataset.py | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'src/text_recognizer/datasets/emnist_dataset.py') diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py index a8901d6..9884fdf 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -22,6 +22,7 @@ class EmnistDataset(Dataset): def __init__( self, + pad_token: str = None, train: bool = False, sample_to_balance: bool = False, subsample_fraction: float = None, @@ -32,6 +33,7 @@ class EmnistDataset(Dataset): """Loads the dataset and the mappings. Args: + pad_token (str): The pad token symbol. Defaults to _. train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False. sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False. subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None. @@ -45,6 +47,7 @@ class EmnistDataset(Dataset): subsample_fraction=subsample_fraction, transform=transform, target_transform=target_transform, + pad_token=pad_token, ) self.sample_to_balance = sample_to_balance @@ -53,6 +56,8 @@ class EmnistDataset(Dataset): if transform is None: self.transform = Compose([Transpose(), ToTensor()]) + self.target_transform = None + self.seed = seed def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: -- cgit v1.2.3-70-g09d2