summary | refs | log | tree | commit | diff
path: root/src/text_recognizer/datasets/emnist_dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/datasets/emnist_dataset.py')
-rw-r--r-- src/text_recognizer/datasets/emnist_dataset.py 31
1 files changed, 13 insertions, 18 deletions
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index 525df95..f3d65ee 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -260,21 +260,23 @@ class EmnistDataLoaders:
"""
self.splits = splits
- self.sample_to_balance = sample_to_balance
if subsample_fraction is not None:
if not 0.0 < subsample_fraction < 1.0:
raise ValueError("The subsample fraction must be in (0, 1).")
- self.subsample_fraction = subsample_fraction
- self.transform = transform
- self.target_transform = target_transform
+ self.dataset_args = {
+ "sample_to_balance": sample_to_balance,
+ "subsample_fraction": subsample_fraction,
+ "transform": transform,
+ "target_transform": target_transform,
+ "seed": seed,
+ }
self.batch_size = batch_size
self.shuffle = shuffle
self.num_workers = num_workers
self.cuda = cuda
- self.seed = seed
- self._data_loaders = self._fetch_emnist_data_loaders()
+ self._data_loaders = self._load_data_loaders()
def __repr__(self) -> str:
"""Returns information about the dataset."""
@@ -303,7 +305,7 @@ class EmnistDataLoaders:
except KeyError:
raise ValueError(f"Split {split} does not exist.")
- def _fetch_emnist_data_loaders(self) -> Dict[str, DataLoader]:
+ def _load_data_loaders(self) -> Dict[str, DataLoader]:
"""Fetches the EMNIST dataset and return a Dict of PyTorch DataLoaders."""
data_loaders = {}
@@ -311,18 +313,11 @@ class EmnistDataLoaders:
if split in self.splits:
if split == "train":
- train = True
+ self.dataset_args["train"] = True
else:
- train = False
-
- emnist_dataset = EmnistDataset(
- train=train,
- sample_to_balance=self.sample_to_balance,
- subsample_fraction=self.subsample_fraction,
- transform=self.transform,
- target_transform=self.target_transform,
- seed=self.seed,
- )
+ self.dataset_args["train"] = False
+
+ emnist_dataset = EmnistDataset(**self.dataset_args)
emnist_dataset.load_emnist_dataset()