diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-08 08:39:32 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-08 08:39:32 +0200 |
commit | 90090f4b66284ce616ea475dfb04e2ce0e4fdf64 (patch) | |
tree | 9d26d4d453f031e8409403ca9e9193abc478b6a6 /text_recognizer/data/mappings | |
parent | 38dc6ca3b787bcdb54d43ac5c076e08af25d44b2 (diff) |
WIP mappings
Diffstat (limited to 'text_recognizer/data/mappings')
-rw-r--r-- | text_recognizer/data/mappings/emnist.py | 22 |
1 files changed, 17 insertions, 5 deletions
diff --git a/text_recognizer/data/mappings/emnist.py b/text_recognizer/data/mappings/emnist.py index 51e4677..ecd862e 100644 --- a/text_recognizer/data/mappings/emnist.py +++ b/text_recognizer/data/mappings/emnist.py @@ -1,12 +1,15 @@ """Emnist mapping.""" -from typing import List, Optional, Sequence, Union +import json +from pathlib import Path +from typing import Dict, List, Optional, Sequence, Union, Tuple import torch from torch import Tensor -from text_recognizer.data.emnist import emnist_mapping from text_recognizer.data.mappings.base import AbstractMapping +ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json" + class EmnistMapping(AbstractMapping): """Mapping for EMNIST labels.""" @@ -15,13 +18,22 @@ class EmnistMapping(AbstractMapping): self, extra_symbols: Optional[Sequence[str]] = None, lower: bool = True ) -> None: self.extra_symbols = set(extra_symbols) if extra_symbols is not None else None - self.mapping, self.inverse_mapping, self.input_size = emnist_mapping( - self.extra_symbols - ) + self.mapping, self.inverse_mapping, self.input_size = self._load_mapping() if lower: self._to_lower() super().__init__(self.input_size, self.mapping, self.inverse_mapping) + def _load_mapping(self) -> Tuple[List, Dict[str, int], List[int]]: + """Return the EMNIST mapping.""" + with ESSENTIALS_FILENAME.open() as f: + essentials = json.load(f) + mapping = list(essentials["characters"]) + if self.extra_symbols is not None: + mapping += self.extra_symbols + inverse_mapping = {v: k for k, v in enumerate(mapping)} + input_shape = essentials["input_shape"] + return mapping, inverse_mapping, input_shape + def _to_lower(self) -> None: """Converts mapping to lowercase letters only.""" |