diff options
Diffstat (limited to 'src/text_recognizer/datasets/emnist_dataset.py')
-rw-r--r-- | src/text_recognizer/datasets/emnist_dataset.py | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py index 84d316a..cf81b5d 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -1,10 +1,22 @@ """Fetches a DataLoader with the EMNIST dataset with PyTorch.""" +from pathlib import Path from typing import Callable +import click +from loguru import logger from torch.utils.data import DataLoader from torchvision.datasets import EMNIST +@click.command() +@click.option("--split", "-s", default="byclass") +def download_emnist(split: str) -> None: + """Download the EMNIST dataset via the PyTorch class.""" + data_dir = Path(__file__).resolve().parents[3] / "data" + logger.debug(f"Data directory is: {data_dir}") + EMNIST(root=data_dir, split=split, download=True) + + def fetch_dataloader( root: str, split: str, |