summary | refs | log | tree | commit | diff
path: root/src/text_recognizer/datasets/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/datasets/util.py')
-rw-r--r-- src/text_recognizer/datasets/util.py 15
1 files changed, 10 insertions, 5 deletions
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index bf5e772..da87756 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -1,17 +1,14 @@
"""Util functions for datasets."""
import hashlib
-import importlib
import json
import os
from pathlib import Path
import string
-from typing import Callable, Dict, List, Optional, Type, Union
-from urllib.request import urlopen, urlretrieve
+from typing import Dict, List, Optional, Union
+from urllib.request import urlretrieve
-import cv2
from loguru import logger
import numpy as np
-from PIL import Image
import torch
from torch import Tensor
from torchvision.datasets import EMNIST
@@ -50,11 +47,13 @@ class EmnistMapper:
pad_token: str,
init_token: Optional[str] = None,
eos_token: Optional[str] = None,
+ lower: bool = False,
) -> None:
"""Loads the emnist essentials file with the mapping and input shape."""
self.init_token = init_token
self.pad_token = pad_token
self.eos_token = eos_token
+ self.lower = lower
self.essentials = self._load_emnist_essentials()
# Load dataset information.
@@ -120,6 +119,12 @@ class EmnistMapper:
def _augment_emnist_mapping(self) -> None:
"""Augment the mapping with extra symbols."""
# Extra symbols in IAM dataset
+ if self.lower:
+ self._mapping = {
+ k: str(v)
+ for k, v in enumerate(list(range(10)) + list(string.ascii_lowercase))
+ }
+
extra_symbols = [
" ",
"!",