diff options
Diffstat (limited to 'src/text_recognizer/datasets')
-rw-r--r-- | src/text_recognizer/datasets/__init__.py | 2 | ||||
-rw-r--r-- | src/text_recognizer/datasets/data_loader.py | 15 | ||||
-rw-r--r-- | src/text_recognizer/datasets/emnist_dataset.py | 177 | ||||
-rw-r--r-- | src/text_recognizer/datasets/emnist_essentials.json | 1 |
4 files changed, 147 insertions, 48 deletions
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py index cbaf1d9..929cfb9 100644 --- a/src/text_recognizer/datasets/__init__.py +++ b/src/text_recognizer/datasets/__init__.py @@ -1,2 +1,2 @@ """Dataset modules.""" -# from .emnist_dataset import fetch_dataloader +from .data_loader import fetch_data_loader diff --git a/src/text_recognizer/datasets/data_loader.py b/src/text_recognizer/datasets/data_loader.py new file mode 100644 index 0000000..fd55934 --- /dev/null +++ b/src/text_recognizer/datasets/data_loader.py @@ -0,0 +1,15 @@ +"""Data loader collection.""" + +from typing import Dict + +from torch.utils.data import DataLoader + +from text_recognizer.datasets.emnist_dataset import fetch_emnist_data_loader + + +def fetch_data_loader(data_loader_args: Dict) -> DataLoader: + """Fetches the specified PyTorch data loader.""" + if data_loader_args.pop("name").lower() == "emnist": + return fetch_emnist_data_loader(data_loader_args) + else: + raise NotImplementedError diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py index 204faeb..f9c8ffa 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -1,72 +1,155 @@ """Fetches a PyTorch DataLoader with the EMNIST dataset.""" + +import json from pathlib import Path -from typing import Callable +from typing import Callable, Dict, List, Optional -import click from loguru import logger +import numpy as np +from PIL import Image from torch.utils.data import DataLoader from torchvision.datasets import EMNIST +from torchvision.transforms import Compose, ToTensor + + +DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data" +ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json" + + +class Transpose: + """Transposes the EMNIST image to the correct orientation.""" + + def __call__(self, image: Image) -> np.ndarray: + """Swaps axis.""" + return np.array(image).swapaxes(0, 1) + + +def save_emnist_essentials(emnsit_dataset: EMNIST) -> None: + """Extract and saves EMNIST essentials.""" + labels = emnsit_dataset.classes + labels.sort() + mapping = [(i, str(label)) for i, label in enumerate(labels)] + essentials = { + "mapping": mapping, + "input_shape": tuple(emnsit_dataset[0][0].shape[:]), + } + logger.info("Saving emnist essentials...") + with open(ESSENTIALS_FILENAME, "w") as f: + json.dump(essentials, f) -@click.command() -@click.option("--split", "-s", default="byclass") -def download_emnist(split: str) -> None: +def download_emnist() -> None: """Download the EMNIST dataset via the PyTorch class.""" - data_dir = Path(__file__).resolve().parents[3] / "data" - logger.debug(f"Data directory is: {data_dir}") - EMNIST(root=data_dir, split=split, download=True) + logger.info(f"Data directory is: {DATA_DIRNAME}") + dataset = EMNIST(root=DATA_DIRNAME, split="byclass", download=True) + save_emnist_essentials(dataset) -def fetch_dataloader( - root: str, +def load_emnist_mapping() -> Dict: + """Load the EMNIST mapping.""" + with open(str(ESSENTIALS_FILENAME)) as f: + essentials = json.load(f) + return dict(essentials["mapping"]) + + +def _sample_to_balance(dataset: EMNIST, seed: int = 4711) -> None: + """Because the dataset is not balanced, we take at most the mean number of instances per class.""" + np.random.seed(seed) + x = dataset.data + y = dataset.targets + num_to_sample = int(np.bincount(y.flatten()).mean()) + all_sampled_inds = [] + for label in np.unique(y.flatten()): + inds = np.where(y == label)[0] + sampled_inds = np.unique(np.random.choice(inds, num_to_sample)) + all_sampled_inds.append(sampled_inds) + ind = np.concatenate(all_sampled_inds) + x_sampled = x[ind] + y_sampled = y[ind] + dataset.data = x_sampled + dataset.targets = y_sampled + + +def fetch_emnist_dataset( split: str, train: bool, - download: bool, - transform: Callable = None, - target_transform: Callable = None, + sample_to_balance: bool = False, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, +) -> EMNIST: + """Fetch the EMNIST dataset.""" + if transform is None: + transform = Compose([Transpose(), ToTensor()]) + + dataset = EMNIST( + root=DATA_DIRNAME, + split="byclass", + train=train, + download=False, + transform=transform, + target_transform=target_transform, + ) + + if sample_to_balance and split == "byclass": + _sample_to_balance(dataset) + + return dataset + + +def fetch_emnist_data_loader( + splits: List[str], + sample_to_balance: bool = False, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, batch_size: int = 128, shuffle: bool = False, num_workers: int = 0, cuda: bool = True, -) -> DataLoader: - """Down/load the EMNIST dataset and return a PyTorch DataLoader. +) -> Dict[DataLoader]: + """Fetches the EMNIST dataset and return a PyTorch DataLoader. Args: - root (str): Root directory of dataset where EMNIST/processed/training.pt and EMNIST/processed/test.pt - exist. - split (str): The dataset has 6 different splits: byclass, bymerge, balanced, letters, digits and mnist. - This argument specifies which one to use. - train (bool): If True, creates dataset from training.pt, otherwise from test.pt. - download (bool): If true, downloads the dataset from the internet and puts it in root directory. If - dataset is already downloaded, it is not downloaded again. - transform (Callable): A function/transform that takes in an PIL image and returns a transformed version. - E.g, transforms.RandomCrop. - target_transform (Callable): A function/transform that takes in the target and transforms it. - batch_size (int): How many samples per batch to load (the default is 128). - shuffle (bool): Set to True to have the data reshuffled at every epoch (the default is False). - num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be loaded in - the main process (default: 0). - cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning them. + splits (List[str]): One or both of the dataset splits "train" and "val". + sample_to_balance (bool): If true, resamples the unbalanced if the split "byclass" is selected. + Defaults to False. + transform (Optional[Callable]): A function/transform that takes in an PIL image and returns a + transformed version. E.g, transforms.RandomCrop. Defaults to None. + target_transform (Optional[Callable]): A function/transform that takes in the target and transforms + it. + Defaults to None. + batch_size (int): How many samples per batch to load. Defaults to 128. + shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False. + num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be + loaded in the main process. Defaults to 0. + cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning + them. Defaults to True. Returns: - DataLoader: A PyTorch DataLoader with emnist characters. + Dict: A dict containing PyTorch DataLoader(s) with emnist characters. """ - dataset = EMNIST( - root=root, - split=split, - train=train, - download=download, - transform=transform, - target_transform=target_transform, - ) + data_loaders = {} - data_loader = DataLoader( - dataset=dataset, - batch_size=batch_size, - shuffle=shuffle, - num_workers=num_workers, - pin_memory=cuda, - ) + for split in ["train", "val"]: + if split in splits: + + if split == "train": + train = True + else: + train = False + + dataset = fetch_emnist_dataset( + split, train, sample_to_balance, transform, target_transform + ) + + data_loader = DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + pin_memory=cuda, + ) + + data_loaders[split] = data_loader - return data_loader + return data_loaders diff --git a/src/text_recognizer/datasets/emnist_essentials.json b/src/text_recognizer/datasets/emnist_essentials.json new file mode 100644 index 0000000..2a0648a --- /dev/null +++ b/src/text_recognizer/datasets/emnist_essentials.json @@ -0,0 +1 @@ +{"mapping": [[0, "0"], [1, "1"], [2, "2"], [3, "3"], [4, "4"], [5, "5"], [6, "6"], [7, "7"], [8, "8"], [9, "9"], [10, "A"], [11, "B"], [12, "C"], [13, "D"], [14, "E"], [15, "F"], [16, "G"], [17, "H"], [18, "I"], [19, "J"], [20, "K"], [21, "L"], [22, "M"], [23, "N"], [24, "O"], [25, "P"], [26, "Q"], [27, "R"], [28, "S"], [29, "T"], [30, "U"], [31, "V"], [32, "W"], [33, "X"], [34, "Y"], [35, "Z"], [36, "a"], [37, "b"], [38, "c"], [39, "d"], [40, "e"], [41, "f"], [42, "g"], [43, "h"], [44, "i"], [45, "j"], [46, "k"], [47, "l"], [48, "m"], [49, "n"], [50, "o"], [51, "p"], [52, "q"], [53, "r"], [54, "s"], [55, "t"], [56, "u"], [57, "v"], [58, "w"], [59, "x"], [60, "y"], [61, "z"]], "input_shape": [28, 28]} |