blob: fd559349380f3fbb66d4df94b01d550e675f6260 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
|
"""Data loader collection."""
from typing import Dict
from torch.utils.data import DataLoader
from text_recognizer.datasets.emnist_dataset import fetch_emnist_data_loader
def fetch_data_loader(data_loader_args: Dict) -> DataLoader:
    """Fetches the specified PyTorch data loader.

    Args:
        data_loader_args (Dict): Data loader configuration. Must contain a
            ``"name"`` key selecting the loader (matched case-insensitively).
            All remaining entries are forwarded to the selected loader
            factory. The caller's dictionary is not modified.

    Returns:
        DataLoader: The data loader produced by the selected factory.

    Raises:
        KeyError: If ``data_loader_args`` has no ``"name"`` key.
        NotImplementedError: If no data loader exists for the given name.

    """
    # Pop "name" from a shallow copy rather than the caller's dict; mutating
    # the argument would make a second call with the same config fail.
    loader_args = dict(data_loader_args)
    name = loader_args.pop("name")
    if name.lower() == "emnist":
        return fetch_emnist_data_loader(loader_args)
    raise NotImplementedError(f"Data loader '{name}' is not implemented.")
|