Diffstat (limited to 'src/text_recognizer/datasets/emnist_dataset.py')
-rw-r--r--  src/text_recognizer/datasets/emnist_dataset.py  177
1 file changed, 130 insertions, 47 deletions
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index 204faeb..f9c8ffa 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -1,72 +1,155 @@
"""Fetches a PyTorch DataLoader with the EMNIST dataset."""
+
+import json
from pathlib import Path
-from typing import Callable
+from typing import Callable, Dict, List, Optional
-import click
from loguru import logger
+import numpy as np
+from PIL import Image
from torch.utils.data import DataLoader
from torchvision.datasets import EMNIST
+from torchvision.transforms import Compose, ToTensor
+
+
+DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data"
+ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json"
+
+
+class Transpose:
+ """Transposes the EMNIST image to the correct orientation."""
+
+    def __call__(self, image: Image.Image) -> np.ndarray:
+        """Swaps the axes of the image."""
+ return np.array(image).swapaxes(0, 1)
+
+
+def save_emnist_essentials(emnist_dataset: EMNIST) -> None:
+    """Extracts and saves the EMNIST essentials (class mapping and input shape)."""
+    labels = emnist_dataset.classes
+ labels.sort()
+ mapping = [(i, str(label)) for i, label in enumerate(labels)]
+ essentials = {
+ "mapping": mapping,
+        "input_shape": tuple(emnist_dataset.data[0].shape),
+ }
+    logger.info("Saving EMNIST essentials...")
+ with open(ESSENTIALS_FILENAME, "w") as f:
+ json.dump(essentials, f)
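+
+# For reference, the JSON written above looks roughly like the following (mapping truncated,
+# values illustrative; the exact labels come from torchvision's EMNIST "byclass" classes):
+#
+#     {"mapping": [[0, "0"], [1, "1"], ..., [61, "z"]], "input_shape": [28, 28]}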
-@click.command()
-@click.option("--split", "-s", default="byclass")
-def download_emnist(split: str) -> None:
+def download_emnist() -> None:
"""Download the EMNIST dataset via the PyTorch class."""
- data_dir = Path(__file__).resolve().parents[3] / "data"
- logger.debug(f"Data directory is: {data_dir}")
- EMNIST(root=data_dir, split=split, download=True)
+ logger.info(f"Data directory is: {DATA_DIRNAME}")
+ dataset = EMNIST(root=DATA_DIRNAME, split="byclass", download=True)
+ save_emnist_essentials(dataset)
-def fetch_dataloader(
- root: str,
+def load_emnist_mapping() -> Dict[int, str]:
+ """Load the EMNIST mapping."""
+ with open(str(ESSENTIALS_FILENAME)) as f:
+ essentials = json.load(f)
+ return dict(essentials["mapping"])
+
+
+def _sample_to_balance(dataset: EMNIST, seed: int = 4711) -> None:
+ """Because the dataset is not balanced, we take at most the mean number of instances per class."""
+ np.random.seed(seed)
+ x = dataset.data
+ y = dataset.targets
+ num_to_sample = int(np.bincount(y.flatten()).mean())
+ all_sampled_inds = []
+ for label in np.unique(y.flatten()):
+ inds = np.where(y == label)[0]
+ sampled_inds = np.unique(np.random.choice(inds, num_to_sample))
+ all_sampled_inds.append(sampled_inds)
+ ind = np.concatenate(all_sampled_inds)
+ x_sampled = x[ind]
+ y_sampled = y[ind]
+ dataset.data = x_sampled
+ dataset.targets = y_sampled
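+
+# Illustrative note on the balancing above (made-up counts): with class counts [1000, 400, 100]
+# the mean is 500, so 500 indices are drawn per class; since the draw is with replacement and
+# then deduplicated, every class ends up with at most 500 unique samples.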
+
+
+def fetch_emnist_dataset(
split: str,
train: bool,
- download: bool,
- transform: Callable = None,
- target_transform: Callable = None,
+ sample_to_balance: bool = False,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+) -> EMNIST:
+    """Fetch the EMNIST dataset, optionally resampling the unbalanced "byclass" split."""
+ if transform is None:
+ transform = Compose([Transpose(), ToTensor()])
+
+ dataset = EMNIST(
+ root=DATA_DIRNAME,
+        split=split,
+ train=train,
+ download=False,
+ transform=transform,
+ target_transform=target_transform,
+ )
+
+ if sample_to_balance and split == "byclass":
+ _sample_to_balance(dataset)
+
+ return dataset
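+
+# Example usage (illustrative; assumes the data was already downloaded via download_emnist()):
+#
+#     train_dataset = fetch_emnist_dataset(split="byclass", train=True, sample_to_balance=True)
+#     image, target = train_dataset[0]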
+
+
+def fetch_emnist_data_loader(
+ splits: List[str],
+ sample_to_balance: bool = False,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
batch_size: int = 128,
shuffle: bool = False,
num_workers: int = 0,
cuda: bool = True,
-) -> DataLoader:
- """Down/load the EMNIST dataset and return a PyTorch DataLoader.
+) -> Dict[str, DataLoader]:
+    """Fetches the EMNIST dataset and returns PyTorch DataLoaders for the requested splits.
Args:
- root (str): Root directory of dataset where EMNIST/processed/training.pt and EMNIST/processed/test.pt
- exist.
- split (str): The dataset has 6 different splits: byclass, bymerge, balanced, letters, digits and mnist.
- This argument specifies which one to use.
- train (bool): If True, creates dataset from training.pt, otherwise from test.pt.
- download (bool): If true, downloads the dataset from the internet and puts it in root directory. If
- dataset is already downloaded, it is not downloaded again.
- transform (Callable): A function/transform that takes in an PIL image and returns a transformed version.
- E.g, transforms.RandomCrop.
- target_transform (Callable): A function/transform that takes in the target and transforms it.
- batch_size (int): How many samples per batch to load (the default is 128).
- shuffle (bool): Set to True to have the data reshuffled at every epoch (the default is False).
- num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be loaded in
- the main process (default: 0).
- cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning them.
+ splits (List[str]): One or both of the dataset splits "train" and "val".
+        sample_to_balance (bool): If True, resamples the unbalanced EMNIST "byclass" data so that each
+            class contributes at most the mean number of instances per class. Defaults to False.
+        transform (Optional[Callable]): A function/transform that takes in a PIL image and returns a
+            transformed version, e.g. transforms.RandomCrop. Defaults to None.
+        target_transform (Optional[Callable]): A function/transform that takes in the target and
+            transforms it. Defaults to None.
+ batch_size (int): How many samples per batch to load. Defaults to 128.
+ shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False.
+ num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be
+ loaded in the main process. Defaults to 0.
+ cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning
+ them. Defaults to True.
Returns:
- DataLoader: A PyTorch DataLoader with emnist characters.
+        Dict[str, DataLoader]: A dict mapping each requested split name to a PyTorch DataLoader with
+            EMNIST characters.
"""
- dataset = EMNIST(
- root=root,
- split=split,
- train=train,
- download=download,
- transform=transform,
- target_transform=target_transform,
- )
+ data_loaders = {}
- data_loader = DataLoader(
- dataset=dataset,
- batch_size=batch_size,
- shuffle=shuffle,
- num_workers=num_workers,
- pin_memory=cuda,
- )
+    for split in ["train", "val"]:
+        if split in splits:
+            train = split == "train"
+
+            # The underlying dataset always uses the EMNIST "byclass" split; the "train"/"val"
+            # split only selects the training or test partition.
+            dataset = fetch_emnist_dataset(
+                split="byclass",
+                train=train,
+                sample_to_balance=sample_to_balance,
+                transform=transform,
+                target_transform=target_transform,
+            )
+
+ data_loader = DataLoader(
+ dataset=dataset,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ num_workers=num_workers,
+ pin_memory=cuda,
+ )
+
+ data_loaders[split] = data_loader
- return data_loader
+ return data_loaders
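
A minimal usage sketch of the new API (assuming the package is importable as text_recognizer and that download_emnist() has been run once to fetch the data and write emnist_essentials.json):

    from text_recognizer.datasets.emnist_dataset import (
        download_emnist,
        fetch_emnist_data_loader,
        load_emnist_mapping,
    )

    download_emnist()  # one-off download; also writes emnist_essentials.json
    data_loaders = fetch_emnist_data_loader(
        splits=["train", "val"], sample_to_balance=True, shuffle=True
    )
    mapping = load_emnist_mapping()  # {target index -> character}
    images, targets = next(iter(data_loaders["train"]))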