summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/datasets')
-rw-r--r--src/text_recognizer/datasets/__init__.py2
-rw-r--r--src/text_recognizer/datasets/emnist_dataset.py9
2 files changed, 7 insertions, 4 deletions
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py
index aec5bf9..795be90 100644
--- a/src/text_recognizer/datasets/__init__.py
+++ b/src/text_recognizer/datasets/__init__.py
@@ -1,2 +1,4 @@
"""Dataset modules."""
from .emnist_dataset import EmnistDataLoader
+
+__all__ = ["EmnistDataLoader"]
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index a17d7a9..b92b57d 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -2,7 +2,7 @@
import json
from pathlib import Path
-from typing import Callable, Dict, List, Optional
+from typing import Callable, Dict, List, Optional, Type
from loguru import logger
import numpy as np
@@ -102,21 +102,22 @@ class EmnistDataLoader:
self.shuffle = shuffle
self.num_workers = num_workers
self.cuda = cuda
+ self.seed = seed
self._data_loaders = self._fetch_emnist_data_loaders()
@property
def __name__(self) -> str:
"""Returns the name of the dataset."""
- return "EMNIST"
+ return "Emnist"
- def __call__(self, split: str) -> Optional[DataLoader]:
+ def __call__(self, split: str) -> DataLoader:
"""Returns the `split` DataLoader.
Args:
split (str): The dataset split, i.e. train or val.
Returns:
- type: A PyTorch DataLoader.
+ DataLoader: A PyTorch DataLoader.
Raises:
ValueError: If the split does not exist.