summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets/emnist_dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/datasets/emnist_dataset.py')
-rw-r--r--src/text_recognizer/datasets/emnist_dataset.py275
1 files changed, 114 insertions, 161 deletions
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index f3d65ee..96f84e5 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -39,45 +39,101 @@ def download_emnist() -> None:
save_emnist_essentials(dataset)
-def _load_emnist_essentials() -> Dict:
- """Load the EMNIST mapping."""
- with open(str(ESSENTIALS_FILENAME)) as f:
- essentials = json.load(f)
- return essentials
-
-
-def _augment_emnist_mapping(mapping: Dict) -> Dict:
- """Augment the mapping with extra symbols."""
- # Extra symbols in IAM dataset
- extra_symbols = [
- " ",
- "!",
- '"',
- "#",
- "&",
- "'",
- "(",
- ")",
- "*",
- "+",
- ",",
- "-",
- ".",
- "/",
- ":",
- ";",
- "?",
- ]
-
- # padding symbol
- extra_symbols.append("_")
-
- max_key = max(mapping.keys())
- extra_mapping = {}
- for i, symbol in enumerate(extra_symbols):
- extra_mapping[max_key + 1 + i] = symbol
-
- return {**mapping, **extra_mapping}
+class EmnistMapper:
+ """Mapper between network output to Emnist character."""
+
+ def __init__(self) -> None:
+ """Loads the emnist essentials file with the mapping and input shape."""
+ self.essentials = self._load_emnist_essentials()
+ # Load dataset infromation.
+ self._mapping = self._augment_emnist_mapping(dict(self.essentials["mapping"]))
+ self._inverse_mapping = {v: k for k, v in self.mapping.items()}
+ self._num_classes = len(self.mapping)
+ self._input_shape = self.essentials["input_shape"]
+
+ def __call__(self, token: Union[str, int, np.uint8]) -> Union[str, int]:
+ """Maps the token to emnist character or character index.
+
+ If the token is an integer (index), the method will return the Emnist character corresponding to that index.
+ If the token is a str (Emnist character), the method will return the corresponding index for that character.
+
+ Args:
+ token (Union[str, int, np.uint8]): Eihter a string or index (integer).
+
+ Returns:
+ Union[str, int]: The mapping result.
+
+ Raises:
+ KeyError: If the index or string does not exist in the mapping.
+
+ """
+ if (isinstance(token, np.uint8) or isinstance(token, int)) and int(
+ token
+ ) in self.mapping:
+ return self.mapping[int(token)]
+ elif isinstance(token, str) and token in self._inverse_mapping:
+ return self._inverse_mapping[token]
+ else:
+ raise KeyError(f"Token {token} does not exist in the mappings.")
+
+ @property
+ def mapping(self) -> Dict:
+ """Returns the mapping between index and character."""
+ return self._mapping
+
+ @property
+ def inverse_mapping(self) -> Dict:
+ """Returns the mapping between character and index."""
+ return self._inverse_mapping
+
+ @property
+ def num_classes(self) -> int:
+ """Returns the number of classes in the dataset."""
+ return self._num_classes
+
+ @property
+ def input_shape(self) -> List[int]:
+ """Returns the input shape of the Emnist characters."""
+ return self._input_shape
+
+ def _load_emnist_essentials(self) -> Dict:
+ """Load the EMNIST mapping."""
+ with open(str(ESSENTIALS_FILENAME)) as f:
+ essentials = json.load(f)
+ return essentials
+
+ def _augment_emnist_mapping(self, mapping: Dict) -> Dict:
+ """Augment the mapping with extra symbols."""
+ # Extra symbols in IAM dataset
+ extra_symbols = [
+ " ",
+ "!",
+ '"',
+ "#",
+ "&",
+ "'",
+ "(",
+ ")",
+ "*",
+ "+",
+ ",",
+ "-",
+ ".",
+ "/",
+ ":",
+ ";",
+ "?",
+ ]
+
+ # padding symbol
+ extra_symbols.append("_")
+
+ max_key = max(mapping.keys())
+ extra_mapping = {}
+ for i, symbol in enumerate(extra_symbols):
+ extra_mapping[max_key + 1 + i] = symbol
+
+ return {**mapping, **extra_mapping}
class EmnistDataset(Dataset):
@@ -110,10 +166,12 @@ class EmnistDataset(Dataset):
self.train = train
self.sample_to_balance = sample_to_balance
+
if subsample_fraction is not None:
if not 0.0 < subsample_fraction < 1.0:
raise ValueError("The subsample fraction must be in (0, 1).")
self.subsample_fraction = subsample_fraction
+
self.transform = transform
if self.transform is None:
self.transform = Compose([Transpose(), ToTensor()])
@@ -121,17 +179,22 @@ class EmnistDataset(Dataset):
self.target_transform = target_transform
self.seed = seed
- # Load dataset infromation.
- essentials = _load_emnist_essentials()
- self.mapping = _augment_emnist_mapping(dict(essentials["mapping"]))
- self.inverse_mapping = {v: k for k, v in self.mapping.items()}
- self.num_classes = len(self.mapping)
- self.input_shape = essentials["input_shape"]
+ self._mapper = EmnistMapper()
+ self.input_shape = self._mapper.input_shape
+ self.num_classes = self._mapper.num_classes
# Placeholders
self.data = None
self.targets = None
+ # Load dataset.
+ self.load_emnist_dataset()
+
+ @property
+ def mapper(self) -> EmnistMapper:
+ """Returns the EmnistMapper."""
+ return self._mapper
+
def __len__(self) -> int:
"""Returns the length of the dataset."""
return len(self.data)
@@ -162,13 +225,18 @@ class EmnistDataset(Dataset):
return data, targets
+ @property
+ def __name__(self) -> str:
+ """Returns the name of the dataset."""
+ return "EmnistDataset"
+
def __repr__(self) -> str:
"""Returns information about the dataset."""
return (
"EMNIST Dataset\n"
f"Num classes: {self.num_classes}\n"
- f"Mapping: {self.mapping}\n"
f"Input shape: {self.input_shape}\n"
+ f"Mapping: {self.mapper.mapping}\n"
)
def _sample_to_balance(self) -> None:
@@ -217,118 +285,3 @@ class EmnistDataset(Dataset):
if self.subsample_fraction is not None:
self._subsample()
-
-
-class EmnistDataLoaders:
- """Class for Emnist DataLoaders."""
-
- def __init__(
- self,
- splits: List[str],
- sample_to_balance: bool = False,
- subsample_fraction: float = None,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
- batch_size: int = 128,
- shuffle: bool = False,
- num_workers: int = 0,
- cuda: bool = True,
- seed: int = 4711,
- ) -> None:
- """Fetches DataLoaders for given split(s).
-
- Args:
- splits (List[str]): One or both of the dataset splits "train" and "val".
- sample_to_balance (bool): If true, resamples the unbalanced if the split "byclass" is selected.
- Defaults to False.
- subsample_fraction (float): The fraction of the dataset will be loaded. If None or 0 the entire
- dataset will be loaded.
- transform (Optional[Callable]): A function/transform that takes in an PIL image and returns a
- transformed version. E.g, transforms.RandomCrop. Defaults to None.
- target_transform (Optional[Callable]): A function/transform that takes in the target and
- transforms it. Defaults to None.
- batch_size (int): How many samples per batch to load. Defaults to 128.
- shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False.
- num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be
- loaded in the main process. Defaults to 0.
- cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning
- them. Defaults to True.
- seed (int): Seed for sampling.
-
- Raises:
- ValueError: If subsample_fraction is not None and outside the range (0, 1).
-
- """
- self.splits = splits
-
- if subsample_fraction is not None:
- if not 0.0 < subsample_fraction < 1.0:
- raise ValueError("The subsample fraction must be in (0, 1).")
-
- self.dataset_args = {
- "sample_to_balance": sample_to_balance,
- "subsample_fraction": subsample_fraction,
- "transform": transform,
- "target_transform": target_transform,
- "seed": seed,
- }
- self.batch_size = batch_size
- self.shuffle = shuffle
- self.num_workers = num_workers
- self.cuda = cuda
- self._data_loaders = self._load_data_loaders()
-
- def __repr__(self) -> str:
- """Returns information about the dataset."""
- return self._data_loaders[self.splits[0]].dataset.__repr__()
-
- @property
- def __name__(self) -> str:
- """Returns the name of the dataset."""
- return "Emnist"
-
- def __call__(self, split: str) -> DataLoader:
- """Returns the `split` DataLoader.
-
- Args:
- split (str): The dataset split, i.e. train or val.
-
- Returns:
- DataLoader: A PyTorch DataLoader.
-
- Raises:
- ValueError: If the split does not exist.
-
- """
- try:
- return self._data_loaders[split]
- except KeyError:
- raise ValueError(f"Split {split} does not exist.")
-
- def _load_data_loaders(self) -> Dict[str, DataLoader]:
- """Fetches the EMNIST dataset and return a Dict of PyTorch DataLoaders."""
- data_loaders = {}
-
- for split in ["train", "val"]:
- if split in self.splits:
-
- if split == "train":
- self.dataset_args["train"] = True
- else:
- self.dataset_args["train"] = False
-
- emnist_dataset = EmnistDataset(**self.dataset_args)
-
- emnist_dataset.load_emnist_dataset()
-
- data_loader = DataLoader(
- dataset=emnist_dataset,
- batch_size=self.batch_size,
- shuffle=self.shuffle,
- num_workers=self.num_workers,
- pin_memory=self.cuda,
- )
-
- data_loaders[split] = data_loader
-
- return data_loaders