summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-06-03 23:50:13 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-06-03 23:50:13 +0200
commit04cbe33804dddba9f8eed6b930bf653a0473392a (patch)
tree265014a644c73d20d6db061fc2d27eacf00c86d6 /src/text_recognizer/datasets
parentab9af6bdc3274455c7206027f1828c7a609bab11 (diff)
EMNIST dataset working.
Diffstat (limited to 'src/text_recognizer/datasets')
-rw-r--r--src/text_recognizer/datasets/emnist_dataset.py12
1 files changed, 12 insertions, 0 deletions
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index 84d316a..cf81b5d 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -1,10 +1,22 @@
"""Fetches a DataLoader with the EMNIST dataset with PyTorch."""
+from pathlib import Path
from typing import Callable
+import click
+from loguru import logger
from torch.utils.data import DataLoader
from torchvision.datasets import EMNIST
+@click.command()
+@click.option("--split", "-s", default="byclass")
+def download_emnist(split: str) -> None:
+ """Download the EMNIST dataset via the PyTorch class."""
+ data_dir = Path(__file__).resolve().parents[3] / "data"
+ logger.debug(f"Data directory is: {data_dir}")
+ EMNIST(root=data_dir, split=split, download=True)
+
+
def fetch_dataloader(
root: str,
split: str,