| | | |
|---|---|---|
| author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-09 23:31:31 +0200 |
| committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-09 23:31:31 +0200 |
| commit | 2b63fd952bdc9c7c72edd501cbcdbf3231e98f00 (patch) | |
| tree | 1c0e0898cb8b66faff9e5d410aa1f82d13542f68 /src/text_recognizer/datasets/util.py | |
| parent | e1b504bca41a9793ed7e88ef14f2e2cbd85724f2 (diff) | |
Created an abstract Dataset class for common methods.
Diffstat (limited to 'src/text_recognizer/datasets/util.py')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | src/text_recognizer/datasets/util.py | 125 |

1 file changed, 120 insertions, 5 deletions
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index dd16bed..3acf5db 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -1,6 +1,7 @@
 """Util functions for datasets."""
 import hashlib
 import importlib
+import json
 import os
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Type, Union
@@ -11,15 +12,129 @@ from loguru import logger
 import numpy as np
 from PIL import Image
 from torch.utils.data import DataLoader, Dataset
+from torchvision.datasets import EMNIST
 from tqdm import tqdm
 
+DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data"
+ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json"
 
-class Transpose:
-    """Transposes the EMNIST image to the correct orientation."""
 
-    def __call__(self, image: Image) -> np.ndarray:
-        """Swaps axis."""
-        return np.array(image).swapaxes(0, 1)
+def save_emnist_essentials(emnist_dataset: type = EMNIST) -> None:
+    """Extracts and saves EMNIST essentials."""
+    labels = emnist_dataset.classes
+    labels.sort()
+    mapping = [(i, str(label)) for i, label in enumerate(labels)]
+    essentials = {
+        "mapping": mapping,
+        "input_shape": tuple(emnist_dataset[0][0].shape[:]),
+    }
+    logger.info("Saving emnist essentials...")
+    with open(ESSENTIALS_FILENAME, "w") as f:
+        json.dump(essentials, f)
+
+
+def download_emnist() -> None:
+    """Download the EMNIST dataset via the PyTorch class."""
+    logger.info(f"Data directory is: {DATA_DIRNAME}")
+    dataset = EMNIST(root=DATA_DIRNAME, split="byclass", download=True)
+    save_emnist_essentials(dataset)
+
+
+class EmnistMapper:
+    """Mapper between network output and EMNIST characters."""
+
+    def __init__(self) -> None:
+        """Loads the emnist essentials file with the mapping and input shape."""
+        self.essentials = self._load_emnist_essentials()
+        # Load dataset information.
+        self._mapping = self._augment_emnist_mapping(dict(self.essentials["mapping"]))
+        self._inverse_mapping = {v: k for k, v in self.mapping.items()}
+        self._num_classes = len(self.mapping)
+        self._input_shape = self.essentials["input_shape"]
+
+    def __call__(self, token: Union[str, int, np.uint8]) -> Union[str, int]:
+        """Maps the token to an EMNIST character or character index.
+
+        If the token is an integer (index), the method will return the EMNIST character corresponding to that index.
+        If the token is a str (EMNIST character), the method will return the corresponding index for that character.
+
+        Args:
+            token (Union[str, int, np.uint8]): Either a string or an index (integer).
+
+        Returns:
+            Union[str, int]: The mapping result.
+
+        Raises:
+            KeyError: If the index or string does not exist in the mapping.
+
+        """
+        if (isinstance(token, np.uint8) or isinstance(token, int)) and int(
+            token
+        ) in self.mapping:
+            return self.mapping[int(token)]
+        elif isinstance(token, str) and token in self._inverse_mapping:
+            return self._inverse_mapping[token]
+        else:
+            raise KeyError(f"Token {token} does not exist in the mappings.")
+
+    @property
+    def mapping(self) -> Dict:
+        """Returns the mapping between index and character."""
+        return self._mapping
+
+    @property
+    def inverse_mapping(self) -> Dict:
+        """Returns the mapping between character and index."""
+        return self._inverse_mapping
+
+    @property
+    def num_classes(self) -> int:
+        """Returns the number of classes in the dataset."""
+        return self._num_classes
+
+    @property
+    def input_shape(self) -> List[int]:
+        """Returns the input shape of the EMNIST characters."""
+        return self._input_shape
+
+    def _load_emnist_essentials(self) -> Dict:
+        """Load the EMNIST mapping."""
+        with open(str(ESSENTIALS_FILENAME)) as f:
+            essentials = json.load(f)
+        return essentials
+
+    def _augment_emnist_mapping(self, mapping: Dict) -> Dict:
+        """Augment the mapping with extra symbols."""
+        # Extra symbols in the IAM dataset.
+        extra_symbols = [
+            " ",
+            "!",
+            '"',
+            "#",
+            "&",
+            "'",
+            "(",
+            ")",
+            "*",
+            "+",
+            ",",
+            "-",
+            ".",
+            "/",
+            ":",
+            ";",
+            "?",
+        ]
+
+        # Padding symbol.
+        extra_symbols.append("_")
+
+        max_key = max(mapping.keys())
+        extra_mapping = {}
+        for i, symbol in enumerate(extra_symbols):
+            extra_mapping[max_key + 1 + i] = symbol
+
+        return {**mapping, **extra_mapping}
 
 
 def compute_sha256(filename: Union[Path, str]) -> str:
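For orientation, here is a minimal usage sketch (not part of the commit) of the new `EmnistMapper`. It assumes `emnist_essentials.json` has already been written next to `util.py`, e.g. by running `download_emnist()` once, and that `src/` is on the Python path so `text_recognizer` is importable.

```python
# Hypothetical usage sketch; the import path assumes src/ is on sys.path.
from text_recognizer.datasets.util import EmnistMapper, download_emnist

# One-time setup: downloads EMNIST via torchvision and writes emnist_essentials.json.
# download_emnist()

# The mapper loads the essentials file and augments the EMNIST "byclass" labels
# with the extra IAM punctuation symbols and the "_" padding symbol.
mapper = EmnistMapper()

print(mapper.num_classes)   # EMNIST classes plus the extra symbols
print(mapper.input_shape)   # input shape stored in the essentials file

# Calling the mapper resolves tokens in either direction:
char = mapper(36)           # index -> character
index = mapper(char)        # character -> index (via the inverse mapping)
assert index == 36

# Unknown tokens raise KeyError:
# mapper("<unknown>")
```

Routing everything through `__call__` means the same object can translate both network predictions (indices) and ground-truth strings, without the caller having to pick a direction explicitly.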