Diffstat (limited to 'src/text_recognizer/datasets/emnist_dataset.py')
-rw-r--r--  src/text_recognizer/datasets/emnist_dataset.py | 228
1 file changed, 31 insertions(+), 197 deletions(-)
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index 0715aae..81268fb 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -2,139 +2,26 @@
import json
from pathlib import Path
-from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+from typing import Callable, Optional, Tuple, Union
from loguru import logger
import numpy as np
from PIL import Image
import torch
from torch import Tensor
-from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import EMNIST
-from torchvision.transforms import Compose, Normalize, ToTensor
+from torchvision.transforms import Compose, ToTensor
-from text_recognizer.datasets.util import Transpose
+from text_recognizer.datasets.dataset import Dataset
+from text_recognizer.datasets.util import DATA_DIRNAME
-DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data"
-ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json"
+class Transpose:
+ """Transposes the EMNIST image to the correct orientation."""
-def save_emnist_essentials(emnist_dataset: type = EMNIST) -> None:
- """Extracts and saves EMNIST essentials."""
- labels = emnist_dataset.classes
- labels.sort()
- mapping = [(i, str(label)) for i, label in enumerate(labels)]
- essentials = {
- "mapping": mapping,
- "input_shape": tuple(emnist_dataset[0][0].shape[:]),
- }
- logger.info("Saving emnist essentials...")
- with open(ESSENTIALS_FILENAME, "w") as f:
- json.dump(essentials, f)
-
-
-def download_emnist() -> None:
- """Download the EMNIST dataset via the PyTorch class."""
- logger.info(f"Data directory is: {DATA_DIRNAME}")
- dataset = EMNIST(root=DATA_DIRNAME, split="byclass", download=True)
- save_emnist_essentials(dataset)
-
-
-class EmnistMapper:
- """Mapper between network output to Emnist character."""
-
- def __init__(self) -> None:
- """Loads the emnist essentials file with the mapping and input shape."""
- self.essentials = self._load_emnist_essentials()
- # Load dataset information.
- self._mapping = self._augment_emnist_mapping(dict(self.essentials["mapping"]))
- self._inverse_mapping = {v: k for k, v in self.mapping.items()}
- self._num_classes = len(self.mapping)
- self._input_shape = self.essentials["input_shape"]
-
- def __call__(self, token: Union[str, int, np.uint8]) -> Union[str, int]:
- """Maps the token to emnist character or character index.
-
- If the token is an integer (index), the method will return the Emnist character corresponding to that index.
- If the token is a str (Emnist character), the method will return the corresponding index for that character.
-
- Args:
- token (Union[str, int, np.uint8]): Either a string or index (integer).
-
- Returns:
- Union[str, int]: The mapping result.
-
- Raises:
- KeyError: If the index or string does not exist in the mapping.
-
- """
- if (isinstance(token, np.uint8) or isinstance(token, int)) and int(
- token
- ) in self.mapping:
- return self.mapping[int(token)]
- elif isinstance(token, str) and token in self._inverse_mapping:
- return self._inverse_mapping[token]
- else:
- raise KeyError(f"Token {token} does not exist in the mappings.")
-
- @property
- def mapping(self) -> Dict:
- """Returns the mapping between index and character."""
- return self._mapping
-
- @property
- def inverse_mapping(self) -> Dict:
- """Returns the mapping between character and index."""
- return self._inverse_mapping
-
- @property
- def num_classes(self) -> int:
- """Returns the number of classes in the dataset."""
- return self._num_classes
-
- @property
- def input_shape(self) -> List[int]:
- """Returns the input shape of the Emnist characters."""
- return self._input_shape
-
- def _load_emnist_essentials(self) -> Dict:
- """Load the EMNIST mapping."""
- with open(str(ESSENTIALS_FILENAME)) as f:
- essentials = json.load(f)
- return essentials
-
- def _augment_emnist_mapping(self, mapping: Dict) -> Dict:
- """Augment the mapping with extra symbols."""
- # Extra symbols in IAM dataset
- extra_symbols = [
- " ",
- "!",
- '"',
- "#",
- "&",
- "'",
- "(",
- ")",
- "*",
- "+",
- ",",
- "-",
- ".",
- "/",
- ":",
- ";",
- "?",
- ]
-
- # padding symbol
- extra_symbols.append("_")
-
- max_key = max(mapping.keys())
- extra_mapping = {}
- for i, symbol in enumerate(extra_symbols):
- extra_mapping[max_key + 1 + i] = symbol
-
- return {**mapping, **extra_mapping}
+ def __call__(self, image: Image) -> np.ndarray:
+ """Swaps axis."""
+ return np.array(image).swapaxes(0, 1)
class EmnistDataset(Dataset):
@@ -159,70 +46,33 @@ class EmnistDataset(Dataset):
target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None.
seed (int): Seed number. Defaults to 4711.
- Raises:
- ValueError: If subsample_fraction is not None and outside the range (0, 1).
-
"""
+ super().__init__(
+ train=train,
+ subsample_fraction=subsample_fraction,
+ transform=transform,
+ target_transform=target_transform,
+ )
- self.train = train
self.sample_to_balance = sample_to_balance
- if subsample_fraction is not None:
- if not 0.0 < subsample_fraction < 1.0:
- raise ValueError("The subsample fraction must be in (0, 1).")
- self.subsample_fraction = subsample_fraction
-
- self.transform = transform
- if self.transform is None:
+ # The EMNIST images have to be transposed into the correct orientation; ToTensor also normalizes inputs to [0, 1].
+ if transform is None:
self.transform = Compose([Transpose(), ToTensor()])
+ # The EMNIST dataset is already cast to tensors.
self.target_transform = target_transform
- self.seed = seed
-
- self._mapper = EmnistMapper()
- self._input_shape = self._mapper.input_shape
- self.num_classes = self._mapper.num_classes
-
- # Load dataset.
- self._data, self._targets = self.load_emnist_dataset()
-
- @property
- def data(self) -> Tensor:
- """The input data."""
- return self._data
- @property
- def targets(self) -> Tensor:
- """The target data."""
- return self._targets
-
- @property
- def input_shape(self) -> Tuple:
- """Input shape of the data."""
- return self._input_shape
-
- @property
- def mapper(self) -> EmnistMapper:
- """Returns the EmnistMapper."""
- return self._mapper
-
- @property
- def inverse_mapping(self) -> Dict:
- """Returns the inverse mapping from character to index."""
- return self.mapper.inverse_mapping
-
- def __len__(self) -> int:
- """Returns the length of the dataset."""
- return len(self.data)
+ self.seed = seed
def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
"""Fetches samples from the dataset.
Args:
- index (Union[int, torch.Tensor]): The indices of the samples to fetch.
+ index (Union[int, Tensor]): The indices of the samples to fetch.
Returns:
- Tuple[torch.Tensor, torch.Tensor]: Data target tuple.
+ Tuple[Tensor, Tensor]: Data target tuple.
"""
if torch.is_tensor(index):
@@ -248,13 +98,11 @@ class EmnistDataset(Dataset):
f"Mapping: {self.mapper.mapping}\n"
)
- def _sample_to_balance(
- self, data: Tensor, targets: Tensor
- ) -> Tuple[np.ndarray, np.ndarray]:
+ def _sample_to_balance(self) -> None:
"""Because the dataset is not balanced, we take at most the mean number of instances per class."""
np.random.seed(self.seed)
- x = data
- y = targets
+ x = self._data
+ y = self._targets
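+ # Cap every class at the mean number of samples per class so over-represented classes are downsampled.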
num_to_sample = int(np.bincount(y.flatten()).mean())
all_sampled_indices = []
for label in np.unique(y.flatten()):
@@ -264,22 +112,10 @@ class EmnistDataset(Dataset):
indices = np.concatenate(all_sampled_indices)
x_sampled = x[indices]
y_sampled = y[indices]
- data = x_sampled
- targets = y_sampled
- return data, targets
-
- def _subsample(self, data: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
- """Subsamples the dataset to the specified fraction."""
- x = data
- y = targets
- num_samples = int(x.shape[0] * self.subsample_fraction)
- x_sampled = x[:num_samples]
- y_sampled = y[:num_samples]
- self.data = x_sampled
- self.targets = y_sampled
- return data, targets
+ self._data = x_sampled
+ self._targets = y_sampled
- def load_emnist_dataset(self) -> Tuple[Tensor, Tensor]:
+ def load_or_generate_data(self) -> None:
"""Fetch the EMNIST dataset."""
dataset = EMNIST(
root=DATA_DIRNAME,
@@ -290,13 +126,11 @@ class EmnistDataset(Dataset):
target_transform=None,
)
- data = dataset.data
- targets = dataset.targets
+ self._data = dataset.data
+ self._targets = dataset.targets
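+ # Work with the raw torchvision tensors; _sample_to_balance and _subsample update self._data and self._targets directly.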
if self.sample_to_balance:
- data, targets = self._sample_to_balance(data, targets)
+ self._sample_to_balance()
if self.subsample_fraction is not None:
- data, targets = self._subsample(data, targets)
-
- return data, targets
+ self._subsample()
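
A minimal usage sketch of the refactored class (illustrative only; it assumes the text_recognizer package is importable and that the new Dataset base class keeps the torch-compatible __len__/__getitem__ interface, the batch size and DataLoader wiring are not part of this commit):

    from torch.utils.data import DataLoader
    from text_recognizer.datasets.emnist_dataset import EmnistDataset

    # Build the training split with the default Transpose() + ToTensor() pipeline.
    dataset = EmnistDataset(train=True, sample_to_balance=True, subsample_fraction=None)
    dataset.load_or_generate_data()  # fetch EMNIST via torchvision, then apply balancing/subsampling
    loader = DataLoader(dataset, batch_size=64, shuffle=True)
    images, targets = next(iter(loader))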