Diffstat (limited to 'src/text_recognizer/datasets')

 src/text_recognizer/datasets/__init__.py       |   2 +-
 src/text_recognizer/datasets/data_loader.py    |  15 ----
 src/text_recognizer/datasets/emnist_dataset.py | 258 +++++++++++++-------
 3 files changed, 156 insertions(+), 119 deletions(-)
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py
index 929cfb9..aec5bf9 100644
--- a/src/text_recognizer/datasets/__init__.py
+++ b/src/text_recognizer/datasets/__init__.py
@@ -1,2 +1,2 @@
"""Dataset modules."""
-from .data_loader import fetch_data_loader
+from .emnist_dataset import EmnistDataLoader
diff --git a/src/text_recognizer/datasets/data_loader.py b/src/text_recognizer/datasets/data_loader.py
deleted file mode 100644
index fd55934..0000000
--- a/src/text_recognizer/datasets/data_loader.py
+++ /dev/null
@@ -1,15 +0,0 @@
-"""Data loader collection."""
-
-from typing import Dict
-
-from torch.utils.data import DataLoader
-
-from text_recognizer.datasets.emnist_dataset import fetch_emnist_data_loader
-
-
-def fetch_data_loader(data_loader_args: Dict) -> DataLoader:
- """Fetches the specified PyTorch data loader."""
- if data_loader_args.pop("name").lower() == "emnist":
- return fetch_emnist_data_loader(data_loader_args)
- else:
- raise NotImplementedError
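With fetch_data_loader removed, the package's entry point becomes the EmnistDataLoader class re-exported from __init__.py above. A minimal migration sketch (the keyword arguments are illustrative, taken from the new class signature further down):

    # Before this commit: dispatch on the "name" key of an args dict.
    # loaders = fetch_data_loader({"name": "emnist", "splits": ["train", "val"]})

    # After this commit: construct the loader class directly.
    from text_recognizer.datasets import EmnistDataLoader

    loaders = EmnistDataLoader(splits=["train", "val"], batch_size=128)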
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index f9c8ffa..a17d7a9 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -24,7 +24,7 @@ class Transpose:
return np.array(image).swapaxes(0, 1)
-def save_emnist_essentials(emnsit_dataset: EMNIST) -> None:
-    """Extract and saves EMNIST essentials."""
-    labels = emnsit_dataset.classes
+def save_emnist_essentials(emnist_dataset: EMNIST) -> None:
+    """Extracts and saves EMNIST essentials."""
+    labels = emnist_dataset.classes
     labels.sort()
@@ -45,111 +45,163 @@ def download_emnist() -> None:
save_emnist_essentials(dataset)
-def load_emnist_mapping() -> Dict:
+def load_emnist_mapping() -> Dict[int, str]:
"""Load the EMNIST mapping."""
with open(str(ESSENTIALS_FILENAME)) as f:
essentials = json.load(f)
return dict(essentials["mapping"])
-def _sample_to_balance(dataset: EMNIST, seed: int = 4711) -> None:
- """Because the dataset is not balanced, we take at most the mean number of instances per class."""
- np.random.seed(seed)
- x = dataset.data
- y = dataset.targets
- num_to_sample = int(np.bincount(y.flatten()).mean())
- all_sampled_inds = []
- for label in np.unique(y.flatten()):
- inds = np.where(y == label)[0]
- sampled_inds = np.unique(np.random.choice(inds, num_to_sample))
- all_sampled_inds.append(sampled_inds)
- ind = np.concatenate(all_sampled_inds)
- x_sampled = x[ind]
- y_sampled = y[ind]
- dataset.data = x_sampled
- dataset.targets = y_sampled
-
-
-def fetch_emnist_dataset(
- split: str,
- train: bool,
- sample_to_balance: bool = False,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
-) -> EMNIST:
- """Fetch the EMNIST dataset."""
- if transform is None:
- transform = Compose([Transpose(), ToTensor()])
-
- dataset = EMNIST(
- root=DATA_DIRNAME,
- split="byclass",
- train=train,
- download=False,
- transform=transform,
- target_transform=target_transform,
- )
-
- if sample_to_balance and split == "byclass":
- _sample_to_balance(dataset)
-
- return dataset
-
-
-def fetch_emnist_data_loader(
- splits: List[str],
- sample_to_balance: bool = False,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
- batch_size: int = 128,
- shuffle: bool = False,
- num_workers: int = 0,
- cuda: bool = True,
-) -> Dict[DataLoader]:
- """Fetches the EMNIST dataset and return a PyTorch DataLoader.
-
- Args:
- splits (List[str]): One or both of the dataset splits "train" and "val".
- sample_to_balance (bool): If true, resamples the unbalanced if the split "byclass" is selected.
- Defaults to False.
- transform (Optional[Callable]): A function/transform that takes in an PIL image and returns a
- transformed version. E.g, transforms.RandomCrop. Defaults to None.
- target_transform (Optional[Callable]): A function/transform that takes in the target and transforms
- it.
- Defaults to None.
- batch_size (int): How many samples per batch to load. Defaults to 128.
- shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False.
- num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be
- loaded in the main process. Defaults to 0.
- cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning
- them. Defaults to True.
-
- Returns:
- Dict: A dict containing PyTorch DataLoader(s) with emnist characters.
-
- """
- data_loaders = {}
-
- for split in ["train", "val"]:
- if split in splits:
-
- if split == "train":
- train = True
- else:
- train = False
-
- dataset = fetch_emnist_dataset(
- split, train, sample_to_balance, transform, target_transform
- )
-
- data_loader = DataLoader(
- dataset=dataset,
- batch_size=batch_size,
- shuffle=shuffle,
- num_workers=num_workers,
- pin_memory=cuda,
- )
-
- data_loaders[split] = data_loader
-
- return data_loaders
+class EmnistDataLoader:
+    """Class for Emnist DataLoaders."""
+
+    def __init__(
+        self,
+        splits: List[str],
+        sample_to_balance: bool = False,
+        subsample_fraction: Optional[float] = None,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        batch_size: int = 128,
+        shuffle: bool = False,
+        num_workers: int = 0,
+        cuda: bool = True,
+        seed: int = 4711,
+    ) -> None:
+        """Fetches DataLoaders.
+
+        Args:
+            splits (List[str]): One or both of the dataset splits "train" and "val".
+            sample_to_balance (bool): If True, resamples the unbalanced dataset so that each class
+                has at most the mean number of instances per class. Defaults to False.
+            subsample_fraction (Optional[float]): The fraction of the dataset that will be loaded.
+                If None, the entire dataset will be loaded. Defaults to None.
+            transform (Optional[Callable]): A function/transform that takes in a PIL image and
+                returns a transformed version, e.g. transforms.RandomCrop. Defaults to None.
+            target_transform (Optional[Callable]): A function/transform that takes in the target
+                and transforms it. Defaults to None.
+            batch_size (int): How many samples per batch to load. Defaults to 128.
+            shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to
+                False.
+            num_workers (int): How many subprocesses to use for data loading. 0 means that the
+                data will be loaded in the main process. Defaults to 0.
+            cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory
+                before returning them. Defaults to True.
+            seed (int): Seed for the random resampling used when balancing the dataset.
+                Defaults to 4711.
+
+        """
+        self.splits = splits
+        self.sample_to_balance = sample_to_balance
+        if subsample_fraction is not None:
+            assert (
+                0.0 < subsample_fraction < 1.0
+            ), "The subsample fraction must be in (0, 1)."
+        self.subsample_fraction = subsample_fraction
+        self.transform = transform
+        self.target_transform = target_transform
+        self.batch_size = batch_size
+        self.shuffle = shuffle
+        self.num_workers = num_workers
+        self.cuda = cuda
+        self.seed = seed
+        self._data_loaders = self._fetch_emnist_data_loaders()
+
+    @property
+    def __name__(self) -> str:
+        """Returns the name of the dataset."""
+        return "EMNIST"
+
+    def __call__(self, split: str) -> DataLoader:
+        """Returns the `split` DataLoader.
+
+        Args:
+            split (str): The dataset split, i.e. train or val.
+
+        Returns:
+            DataLoader: A PyTorch DataLoader for the given split.
+
+        Raises:
+            ValueError: If the split does not exist.
+
+        """
+        try:
+            return self._data_loaders[split]
+        except KeyError as e:
+            raise ValueError(f"Split {split} does not exist.") from e
+
+    def _sample_to_balance(self, dataset: EMNIST) -> EMNIST:
+        """Because the dataset is not balanced, we take at most the mean number of instances per class."""
+        np.random.seed(self.seed)
+        x = dataset.data
+        y = dataset.targets
+        num_to_sample = int(np.bincount(y.flatten()).mean())
+        all_sampled_indices = []
+        for label in np.unique(y.flatten()):
+            inds = np.where(y == label)[0]
+            sampled_indices = np.unique(np.random.choice(inds, num_to_sample))
+            all_sampled_indices.append(sampled_indices)
+        indices = np.concatenate(all_sampled_indices)
+        x_sampled = x[indices]
+        y_sampled = y[indices]
+        dataset.data = x_sampled
+        dataset.targets = y_sampled
+
+        return dataset
+
+    def _subsample(self, dataset: EMNIST) -> EMNIST:
+        """Subsamples the dataset to the specified fraction."""
+        x = dataset.data
+        y = dataset.targets
+        num_samples = int(x.shape[0] * self.subsample_fraction)
+        x_sampled = x[:num_samples]
+        y_sampled = y[:num_samples]
+        dataset.data = x_sampled
+        dataset.targets = y_sampled
+
+        return dataset
+
+    def _fetch_emnist_dataset(self, train: bool) -> EMNIST:
+        """Fetch the EMNIST dataset."""
+        if self.transform is None:
+            transform = Compose([Transpose(), ToTensor()])
+        else:
+            transform = self.transform
+
+        dataset = EMNIST(
+            root=DATA_DIRNAME,
+            split="byclass",
+            train=train,
+            download=False,
+            transform=transform,
+            target_transform=self.target_transform,
+        )
+
+        if self.sample_to_balance:
+            dataset = self._sample_to_balance(dataset)
+
+        if self.subsample_fraction is not None:
+            dataset = self._subsample(dataset)
+
+        return dataset
+
+    def _fetch_emnist_data_loaders(self) -> Dict[str, DataLoader]:
+        """Fetches the EMNIST dataset and returns a Dict of PyTorch DataLoaders."""
+        data_loaders = {}
+
+        for split in ["train", "val"]:
+            if split in self.splits:
+                train = split == "train"
+                dataset = self._fetch_emnist_dataset(train)
+
+                data_loader = DataLoader(
+                    dataset=dataset,
+                    batch_size=self.batch_size,
+                    shuffle=self.shuffle,
+                    num_workers=self.num_workers,
+                    pin_memory=self.cuda,
+                )
+
+                data_loaders[split] = data_loader
+
+        return data_loaders
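Putting the new class together, a minimal usage sketch, assuming the EMNIST images were already fetched with download_emnist() (the class constructs EMNIST with download=False):

    # Build balanced train/val loaders over the hardcoded "byclass" split,
    # keeping a 10% subsample of the data.
    loaders = EmnistDataLoader(
        splits=["train", "val"],
        sample_to_balance=True,
        subsample_fraction=0.1,
        batch_size=64,
        shuffle=True,
        cuda=False,
    )
    train_loader = loaders("train")  # __call__ looks up the split's DataLoader
    images, targets = next(iter(train_loader))
    print(images.shape)  # e.g. torch.Size([64, 1, 28, 28]) for EMNIST images

    loaders("test")  # raises ValueError: Split test does not exist.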