diff options
Diffstat (limited to 'src/text_recognizer/datasets/emnist_dataset.py')
-rw-r--r-- | src/text_recognizer/datasets/emnist_dataset.py | 31 |
1 files changed, 13 insertions, 18 deletions
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py index 525df95..f3d65ee 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -260,21 +260,23 @@ class EmnistDataLoaders: """ self.splits = splits - self.sample_to_balance = sample_to_balance if subsample_fraction is not None: if not 0.0 < subsample_fraction < 1.0: raise ValueError("The subsample fraction must be in (0, 1).") - self.subsample_fraction = subsample_fraction - self.transform = transform - self.target_transform = target_transform + self.dataset_args = { + "sample_to_balance": sample_to_balance, + "subsample_fraction": subsample_fraction, + "transform": transform, + "target_transform": target_transform, + "seed": seed, + } self.batch_size = batch_size self.shuffle = shuffle self.num_workers = num_workers self.cuda = cuda - self.seed = seed - self._data_loaders = self._fetch_emnist_data_loaders() + self._data_loaders = self._load_data_loaders() def __repr__(self) -> str: """Returns information about the dataset.""" @@ -303,7 +305,7 @@ class EmnistDataLoaders: except KeyError: raise ValueError(f"Split {split} does not exist.") - def _fetch_emnist_data_loaders(self) -> Dict[str, DataLoader]: + def _load_data_loaders(self) -> Dict[str, DataLoader]: """Fetches the EMNIST dataset and return a Dict of PyTorch DataLoaders.""" data_loaders = {} @@ -311,18 +313,11 @@ class EmnistDataLoaders: if split in self.splits: if split == "train": - train = True + self.dataset_args["train"] = True else: - train = False - - emnist_dataset = EmnistDataset( - train=train, - sample_to_balance=self.sample_to_balance, - subsample_fraction=self.subsample_fraction, - transform=self.transform, - target_transform=self.target_transform, - seed=self.seed, - ) + self.dataset_args["train"] = False + + emnist_dataset = EmnistDataset(**self.dataset_args) emnist_dataset.load_emnist_dataset() |