summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets/data_loader.py
blob: fd559349380f3fbb66d4df94b01d550e675f6260 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
"""Data loader collection."""

from typing import Dict

from torch.utils.data import DataLoader

from text_recognizer.datasets.emnist_dataset import fetch_emnist_data_loader


def fetch_data_loader(data_loader_args: Dict) -> DataLoader:
    """Fetches the PyTorch data loader named in ``data_loader_args``.

    Args:
        data_loader_args (Dict): Loader configuration. Must contain a ``"name"``
            key selecting the dataset (compared case-insensitively). All other
            keys are forwarded to the dataset-specific fetch function. The
            dictionary passed in is not modified.

    Returns:
        DataLoader: The data loader returned by the selected fetch function.

    Raises:
        KeyError: If ``data_loader_args`` has no ``"name"`` key.
        NotImplementedError: If no data loader exists for the requested name.

    """
    # Pop "name" from a shallow copy rather than the caller's dict. Popping in
    # place made a second call with the same config fail with KeyError.
    loader_args = dict(data_loader_args)
    name = loader_args.pop("name")
    if name.lower() == "emnist":
        return fetch_emnist_data_loader(loader_args)
    raise NotImplementedError(f"Data loader '{name}' is not implemented.")