diff options
Diffstat (limited to 'src/text_recognizer/datasets/util.py')
-rw-r--r-- | src/text_recognizer/datasets/util.py | 19 |
1 files changed, 11 insertions, 8 deletions
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py index d2df8b5..bf5e772 100644 --- a/src/text_recognizer/datasets/util.py +++ b/src/text_recognizer/datasets/util.py @@ -12,7 +12,8 @@ import cv2 from loguru import logger import numpy as np from PIL import Image -from torch.utils.data import DataLoader, Dataset +import torch +from torch import Tensor from torchvision.datasets import EMNIST from tqdm import tqdm @@ -20,7 +21,7 @@ DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data" ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json" -def save_emnist_essentials(emnsit_dataset: type = EMNIST) -> None: +def save_emnist_essentials(emnsit_dataset: EMNIST = EMNIST) -> None: """Extract and saves EMNIST essentials.""" labels = emnsit_dataset.classes labels.sort() @@ -56,21 +57,21 @@ class EmnistMapper: self.eos_token = eos_token self.essentials = self._load_emnist_essentials() - # Load dataset infromation. + # Load dataset information. self._mapping = dict(self.essentials["mapping"]) self._augment_emnist_mapping() self._inverse_mapping = {v: k for k, v in self.mapping.items()} self._num_classes = len(self.mapping) self._input_shape = self.essentials["input_shape"] - def __call__(self, token: Union[str, int, np.uint8]) -> Union[str, int]: + def __call__(self, token: Union[str, int, np.uint8, Tensor]) -> Union[str, int]: """Maps the token to emnist character or character index. If the token is an integer (index), the method will return the Emnist character corresponding to that index. If the token is a str (Emnist character), the method will return the corresponding index for that character. Args: - token (Union[str, int, np.uint8]): Eihter a string or index (integer). + token (Union[str, int, np.uint8, Tensor]): Either a string or index (integer). Returns: Union[str, int]: The mapping result. @@ -79,9 +80,11 @@ class EmnistMapper: KeyError: If the index or string does not exist in the mapping. """ - if (isinstance(token, np.uint8) or isinstance(token, int)) and int( - token - ) in self.mapping: + if ( + (isinstance(token, np.uint8) or isinstance(token, int)) + or torch.is_tensor(token) + and int(token) in self.mapping + ): return self.mapping[int(token)] elif isinstance(token, str) and token in self._inverse_mapping: return self._inverse_mapping[token] |