diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-06-03 23:50:13 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-06-03 23:50:13 +0200 |
commit | 04cbe33804dddba9f8eed6b930bf653a0473392a (patch) | |
tree | 265014a644c73d20d6db061fc2d27eacf00c86d6 /src/text_recognizer/datasets | |
parent | ab9af6bdc3274455c7206027f1828c7a609bab11 (diff) |
EMNIST dataset working.
Diffstat (limited to 'src/text_recognizer/datasets')
-rw-r--r-- | src/text_recognizer/datasets/emnist_dataset.py | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py index 84d316a..cf81b5d 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -1,10 +1,22 @@ """Fetches a DataLoader with the EMNIST dataset with PyTorch.""" +from pathlib import Path from typing import Callable +import click +from loguru import logger from torch.utils.data import DataLoader from torchvision.datasets import EMNIST +@click.command() +@click.option("--split", "-s", default="byclass") +def download_emnist(split: str) -> None: + """Download the EMNIST dataset via the PyTorch class.""" + data_dir = Path(__file__).resolve().parents[3] / "data" + logger.debug(f"Data directory is: {data_dir}") + EMNIST(root=data_dir, split=split, download=True) + + def fetch_dataloader( root: str, split: str, |