summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets/util.py
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-11-15 17:40:44 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2020-11-15 17:40:44 +0100
commit75909723fa2b1f6245d5c5422e4f2e88b8a26052 (patch)
treee60c37d05c724db011d75adf9313d93839d193ac /src/text_recognizer/datasets/util.py
parentcad676fc423efeafde65f03e4815248f2d357011 (diff)
Able to generate support files for lines datasets.
Diffstat (limited to 'src/text_recognizer/datasets/util.py')
-rw-r--r--src/text_recognizer/datasets/util.py19
1 files changed, 11 insertions, 8 deletions
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index d2df8b5..bf5e772 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -12,7 +12,8 @@ import cv2
from loguru import logger
import numpy as np
from PIL import Image
-from torch.utils.data import DataLoader, Dataset
+import torch
+from torch import Tensor
from torchvision.datasets import EMNIST
from tqdm import tqdm
@@ -20,7 +21,7 @@ DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data"
ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json"
-def save_emnist_essentials(emnsit_dataset: type = EMNIST) -> None:
+def save_emnist_essentials(emnsit_dataset: EMNIST = EMNIST) -> None:
"""Extract and saves EMNIST essentials."""
labels = emnsit_dataset.classes
labels.sort()
@@ -56,21 +57,21 @@ class EmnistMapper:
self.eos_token = eos_token
self.essentials = self._load_emnist_essentials()
- # Load dataset infromation.
+ # Load dataset information.
self._mapping = dict(self.essentials["mapping"])
self._augment_emnist_mapping()
self._inverse_mapping = {v: k for k, v in self.mapping.items()}
self._num_classes = len(self.mapping)
self._input_shape = self.essentials["input_shape"]
- def __call__(self, token: Union[str, int, np.uint8]) -> Union[str, int]:
+ def __call__(self, token: Union[str, int, np.uint8, Tensor]) -> Union[str, int]:
"""Maps the token to emnist character or character index.
If the token is an integer (index), the method will return the Emnist character corresponding to that index.
If the token is a str (Emnist character), the method will return the corresponding index for that character.
Args:
- token (Union[str, int, np.uint8]): Eihter a string or index (integer).
+ token (Union[str, int, np.uint8, Tensor]): Either a string or index (integer).
Returns:
Union[str, int]: The mapping result.
@@ -79,9 +80,11 @@ class EmnistMapper:
KeyError: If the index or string does not exist in the mapping.
"""
- if (isinstance(token, np.uint8) or isinstance(token, int)) and int(
- token
- ) in self.mapping:
+ if (
+ (isinstance(token, np.uint8) or isinstance(token, int))
+ or torch.is_tensor(token)
+ and int(token) in self.mapping
+ ):
return self.mapping[int(token)]
elif isinstance(token, str) and token in self._inverse_mapping:
return self._inverse_mapping[token]