From 75909723fa2b1f6245d5c5422e4f2e88b8a26052 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Sun, 15 Nov 2020 17:40:44 +0100 Subject: Able to generate support files for lines datasets. --- src/text_recognizer/datasets/util.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) (limited to 'src/text_recognizer/datasets/util.py') diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py index d2df8b5..bf5e772 100644 --- a/src/text_recognizer/datasets/util.py +++ b/src/text_recognizer/datasets/util.py @@ -12,7 +12,8 @@ import cv2 from loguru import logger import numpy as np from PIL import Image -from torch.utils.data import DataLoader, Dataset +import torch +from torch import Tensor from torchvision.datasets import EMNIST from tqdm import tqdm @@ -20,7 +21,7 @@ DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data" ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json" -def save_emnist_essentials(emnsit_dataset: type = EMNIST) -> None: +def save_emnist_essentials(emnsit_dataset: EMNIST = EMNIST) -> None: """Extract and saves EMNIST essentials.""" labels = emnsit_dataset.classes labels.sort() @@ -56,21 +57,21 @@ class EmnistMapper: self.eos_token = eos_token self.essentials = self._load_emnist_essentials() - # Load dataset infromation. + # Load dataset information. self._mapping = dict(self.essentials["mapping"]) self._augment_emnist_mapping() self._inverse_mapping = {v: k for k, v in self.mapping.items()} self._num_classes = len(self.mapping) self._input_shape = self.essentials["input_shape"] - def __call__(self, token: Union[str, int, np.uint8]) -> Union[str, int]: + def __call__(self, token: Union[str, int, np.uint8, Tensor]) -> Union[str, int]: """Maps the token to emnist character or character index. If the token is an integer (index), the method will return the Emnist character corresponding to that index. If the token is a str (Emnist character), the method will return the corresponding index for that character. Args: - token (Union[str, int, np.uint8]): Eihter a string or index (integer). + token (Union[str, int, np.uint8, Tensor]): Either a string or index (integer). Returns: Union[str, int]: The mapping result. @@ -79,9 +80,11 @@ class EmnistMapper: KeyError: If the index or string does not exist in the mapping. """ - if (isinstance(token, np.uint8) or isinstance(token, int)) and int( - token - ) in self.mapping: + if ( + (isinstance(token, np.uint8) or isinstance(token, int)) + or torch.is_tensor(token) + and int(token) in self.mapping + ): return self.mapping[int(token)] elif isinstance(token, str) and token in self._inverse_mapping: return self._inverse_mapping[token] -- cgit v1.2.3-70-g09d2