diff options
Diffstat (limited to 'src/text_recognizer/datasets/util.py')
-rw-r--r-- | src/text_recognizer/datasets/util.py | 15 |
1 file changed, 10 insertions, 5 deletions
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index bf5e772..da87756 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -1,17 +1,14 @@
 """Util functions for datasets."""
 import hashlib
-import importlib
 import json
 import os
 from pathlib import Path
 import string
-from typing import Callable, Dict, List, Optional, Type, Union
-from urllib.request import urlopen, urlretrieve
+from typing import Dict, List, Optional, Union
+from urllib.request import urlretrieve
 
-import cv2
 from loguru import logger
 import numpy as np
-from PIL import Image
 import torch
 from torch import Tensor
 from torchvision.datasets import EMNIST
@@ -50,11 +47,13 @@ class EmnistMapper:
         pad_token: str,
         init_token: Optional[str] = None,
         eos_token: Optional[str] = None,
+        lower: bool = False,
     ) -> None:
         """Loads the emnist essentials file with the mapping and input shape."""
         self.init_token = init_token
         self.pad_token = pad_token
         self.eos_token = eos_token
+        self.lower = lower
 
         self.essentials = self._load_emnist_essentials()
         # Load dataset information.
@@ -120,6 +119,12 @@ class EmnistMapper:
     def _augment_emnist_mapping(self) -> None:
         """Augment the mapping with extra symbols."""
         # Extra symbols in IAM dataset
+        if self.lower:
+            self._mapping = {
+                k: str(v)
+                for k, v in enumerate(list(range(10)) + list(string.ascii_lowercase))
+            }
+
         extra_symbols = [
             " ",
             "!",