summary | refs | log | tree | commit | diff
path: root/src/text_recognizer/datasets
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/datasets')
-rw-r--r-- src/text_recognizer/datasets/emnist_dataset.py 12
1 files changed, 12 insertions, 0 deletions
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index 84d316a..cf81b5d 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -1,10 +1,22 @@
"""Fetches a DataLoader with the EMNIST dataset with PyTorch."""
+from pathlib import Path
from typing import Callable
+import click
+from loguru import logger
from torch.utils.data import DataLoader
from torchvision.datasets import EMNIST
+@click.command()
+@click.option("--split", "-s", default="byclass")
+def download_emnist(split: str) -> None:
+ """Download the requested EMNIST split into the top-level data/ directory via torchvision."""
+ data_dir = Path(__file__).resolve().parents[3] / "data"
+ logger.debug(f"Data directory is: {data_dir}")
+ EMNIST(root=data_dir, split=split, download=True)
+
+
def fetch_dataloader(
root: str,
split: str,