diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-06-02 22:49:22 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-06-02 22:49:22 +0200 |
commit | 118c02c2730aaac2d10c2b9847339f6ffc83223f (patch) | |
tree | f7e1dc9a7159b63945a19d742a525f20c90c969e /src/text_recognizer/datasets | |
parent | 81d48b6a4da96696afde87a54f9fb7d89dd64cd2 (diff) |
Working on lab 1.
Diffstat (limited to 'src/text_recognizer/datasets')
-rw-r--r-- | src/text_recognizer/datasets/__init__.py | 2 | ||||
-rw-r--r-- | src/text_recognizer/datasets/emnist_dataset.py | 60 |
2 files changed, 62 insertions, 0 deletions
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py new file mode 100644 index 0000000..cbaf1d9 --- /dev/null +++ b/src/text_recognizer/datasets/__init__.py @@ -0,0 +1,2 @@ +"""Dataset modules.""" +# from .emnist_dataset import fetch_dataloader diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py new file mode 100644 index 0000000..67158ec --- /dev/null +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -0,0 +1,60 @@ +"""Fetches a DataLoader for EMNIST dataset with PyTorch.""" +from typing import Callable + +from torch.utils.data import DataLoader +from torchvision.datasets import EMNIST + + +def fetch_dataloader( + root: str, + split: str, + train: bool, + download: bool, + transform: Callable = None, + target_transform: Callable = None, + batch_size: int = 128, + shuffle: bool = False, + num_workers: int = 0, + cuda: bool = True, +) -> DataLoader: + """Down/load the EMNIST dataset and return a PyTorch DataLoader. + + Args: + root (str): Root directory of dataset where EMNIST/processed/training.pt and EMNIST/processed/test.pt + exist. + split (str): The dataset has 6 different splits: byclass, bymerge, balanced, letters, digits and mnist. + This argument specifies which one to use. + train (bool): If True, creates dataset from training.pt, otherwise from test.pt. + download (bool): If true, downloads the dataset from the internet and puts it in root directory. If + dataset is already downloaded, it is not downloaded again. + transform (Callable): A function/transform that takes in an PIL image and returns a transformed version. + E.g, transforms.RandomCrop. + target_transform (Callable): A function/transform that takes in the target and transforms it. + batch_size (int): How many samples per batch to load (the default is 128). + shuffle (bool): Set to True to have the data reshuffled at every epoch (the default is False). + num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be loaded in + the main process (default: 0). + cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning them. + + Returns: + DataLoader: A PyTorch DataLoader with emnist characters. + + """ + dataset = EMNIST( + root=root, + split=split, + train=train, + download=download, + transform=transform, + target_transform=target_transform, + ) + + data_loader = DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + pin_memory=cuda, + ) + + return data_loader |