diff options
Diffstat (limited to 'src/text_recognizer/datasets')
-rw-r--r-- | src/text_recognizer/datasets/__init__.py | 2 | ||||
-rw-r--r-- | src/text_recognizer/datasets/emnist_dataset.py | 9 |
2 files changed, 7 insertions, 4 deletions
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py index aec5bf9..795be90 100644 --- a/src/text_recognizer/datasets/__init__.py +++ b/src/text_recognizer/datasets/__init__.py @@ -1,2 +1,4 @@ """Dataset modules.""" from .emnist_dataset import EmnistDataLoader + +__all__ = ["EmnistDataLoader"] diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py index a17d7a9..b92b57d 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -2,7 +2,7 @@ import json from pathlib import Path -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional, Type from loguru import logger import numpy as np @@ -102,21 +102,22 @@ class EmnistDataLoader: self.shuffle = shuffle self.num_workers = num_workers self.cuda = cuda + self.seed = seed self._data_loaders = self._fetch_emnist_data_loaders() @property def __name__(self) -> str: """Returns the name of the dataset.""" - return "EMNIST" + return "Emnist" - def __call__(self, split: str) -> Optional[DataLoader]: + def __call__(self, split: str) -> DataLoader: """Returns the `split` DataLoader. Args: split (str): The dataset split, i.e. train or val. Returns: - type: A PyTorch DataLoader. + DataLoader: A PyTorch DataLoader. Raises: ValueError: If the split does not exist. |