diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-11-08 14:54:44 +0100 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-11-08 14:54:44 +0100 |
commit | dc28cbe2b4ed77be92ee8b2b69a20689c3bf02a4 (patch) | |
tree | 1b5fc0d06952e13727e85c4f973a26d277068453 /src/text_recognizer/datasets/util.py | |
parent | e181195a699d7fa237f256d90ab4dedffc03d405 (diff) |
new updates
Diffstat (limited to 'src/text_recognizer/datasets/util.py')
-rw-r--r-- | src/text_recognizer/datasets/util.py | 31 |
1 files changed, 24 insertions, 7 deletions
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py index 73968a1..d2df8b5 100644 --- a/src/text_recognizer/datasets/util.py +++ b/src/text_recognizer/datasets/util.py @@ -4,6 +4,7 @@ import importlib import json import os from pathlib import Path +import string from typing import Callable, Dict, List, Optional, Type, Union from urllib.request import urlopen, urlretrieve @@ -26,7 +27,7 @@ def save_emnist_essentials(emnsit_dataset: type = EMNIST) -> None: mapping = [(i, str(label)) for i, label in enumerate(labels)] essentials = { "mapping": mapping, - "input_shape": tuple(emnsit_dataset[0][0].shape[:]), + "input_shape": tuple(np.array(emnsit_dataset[0][0]).shape[:]), } logger.info("Saving emnist essentials...") with open(ESSENTIALS_FILENAME, "w") as f: @@ -43,11 +44,21 @@ def download_emnist() -> None: class EmnistMapper: """Mapper between network output to Emnist character.""" - def __init__(self) -> None: + def __init__( + self, + pad_token: str, + init_token: Optional[str] = None, + eos_token: Optional[str] = None, + ) -> None: """Loads the emnist essentials file with the mapping and input shape.""" + self.init_token = init_token + self.pad_token = pad_token + self.eos_token = eos_token + self.essentials = self._load_emnist_essentials() # Load dataset infromation. - self._mapping = self._augment_emnist_mapping(dict(self.essentials["mapping"])) + self._mapping = dict(self.essentials["mapping"]) + self._augment_emnist_mapping() self._inverse_mapping = {v: k for k, v in self.mapping.items()} self._num_classes = len(self.mapping) self._input_shape = self.essentials["input_shape"] @@ -103,7 +114,7 @@ class EmnistMapper: essentials = json.load(f) return essentials - def _augment_emnist_mapping(self, mapping: Dict) -> Dict: + def _augment_emnist_mapping(self) -> None: """Augment the mapping with extra symbols.""" # Extra symbols in IAM dataset extra_symbols = [ @@ -127,14 +138,20 @@ class EmnistMapper: ] # padding symbol, and acts as blank symbol as well. - extra_symbols.append("_") + extra_symbols.append(self.pad_token) + + if self.init_token is not None: + extra_symbols.append(self.init_token) + + if self.eos_token is not None: + extra_symbols.append(self.eos_token) - max_key = max(mapping.keys()) + max_key = max(self.mapping.keys()) extra_mapping = {} for i, symbol in enumerate(extra_symbols): extra_mapping[max_key + 1 + i] = symbol - return {**mapping, **extra_mapping} + self._mapping = {**self.mapping, **extra_mapping} def compute_sha256(filename: Union[Path, str]) -> str: |