summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/datasets/util.py')
-rw-r--r--src/text_recognizer/datasets/util.py29
1 files changed, 23 insertions, 6 deletions
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index 125f05a..d2df8b5 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -4,6 +4,7 @@ import importlib
import json
import os
from pathlib import Path
+import string
from typing import Callable, Dict, List, Optional, Type, Union
from urllib.request import urlopen, urlretrieve
@@ -43,11 +44,21 @@ def download_emnist() -> None:
class EmnistMapper:
"""Mapper between network output to Emnist character."""
- def __init__(self) -> None:
+ def __init__(
+ self,
+ pad_token: str,
+ init_token: Optional[str] = None,
+ eos_token: Optional[str] = None,
+ ) -> None:
"""Loads the emnist essentials file with the mapping and input shape."""
+ self.init_token = init_token
+ self.pad_token = pad_token
+ self.eos_token = eos_token
+
self.essentials = self._load_emnist_essentials()
# Load dataset infromation.
- self._mapping = self._augment_emnist_mapping(dict(self.essentials["mapping"]))
+ self._mapping = dict(self.essentials["mapping"])
+ self._augment_emnist_mapping()
self._inverse_mapping = {v: k for k, v in self.mapping.items()}
self._num_classes = len(self.mapping)
self._input_shape = self.essentials["input_shape"]
@@ -103,7 +114,7 @@ class EmnistMapper:
essentials = json.load(f)
return essentials
- def _augment_emnist_mapping(self, mapping: Dict) -> Dict:
+ def _augment_emnist_mapping(self) -> None:
"""Augment the mapping with extra symbols."""
# Extra symbols in IAM dataset
extra_symbols = [
@@ -127,14 +138,20 @@ class EmnistMapper:
]
# padding symbol, and acts as blank symbol as well.
- extra_symbols.append("_")
+ extra_symbols.append(self.pad_token)
+
+ if self.init_token is not None:
+ extra_symbols.append(self.init_token)
+
+ if self.eos_token is not None:
+ extra_symbols.append(self.eos_token)
- max_key = max(mapping.keys())
+ max_key = max(self.mapping.keys())
extra_mapping = {}
for i, symbol in enumerate(extra_symbols):
extra_mapping[max_key + 1 + i] = symbol
- return {**mapping, **extra_mapping}
+ self._mapping = {**self.mapping, **extra_mapping}
def compute_sha256(filename: Union[Path, str]) -> str: