summaryrefslogtreecommitdiff
path: root/text_recognizer/data/mappings/emnist.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/mappings/emnist.py')
-rw-r--r--text_recognizer/data/mappings/emnist.py22
1 files changed, 17 insertions, 5 deletions
diff --git a/text_recognizer/data/mappings/emnist.py b/text_recognizer/data/mappings/emnist.py
index 51e4677..ecd862e 100644
--- a/text_recognizer/data/mappings/emnist.py
+++ b/text_recognizer/data/mappings/emnist.py
@@ -1,12 +1,15 @@
"""Emnist mapping."""
-from typing import List, Optional, Sequence, Union
+import json
+from pathlib import Path
+from typing import Dict, List, Optional, Sequence, Union, Tuple
import torch
from torch import Tensor
-from text_recognizer.data.emnist import emnist_mapping
from text_recognizer.data.mappings.base import AbstractMapping
+ESSENTIALS_FILENAME = Path(__file__).parents[0].resolve() / "emnist_essentials.json"
+
class EmnistMapping(AbstractMapping):
"""Mapping for EMNIST labels."""
@@ -15,13 +18,22 @@ class EmnistMapping(AbstractMapping):
self, extra_symbols: Optional[Sequence[str]] = None, lower: bool = True
) -> None:
self.extra_symbols = set(extra_symbols) if extra_symbols is not None else None
- self.mapping, self.inverse_mapping, self.input_size = emnist_mapping(
- self.extra_symbols
- )
+ self.mapping, self.inverse_mapping, self.input_size = self._load_mapping()
if lower:
self._to_lower()
super().__init__(self.input_size, self.mapping, self.inverse_mapping)
+ def _load_mapping(self) -> Tuple[List, Dict[str, int], List[int]]:
+ """Return the EMNIST mapping."""
+ with ESSENTIALS_FILENAME.open() as f:
+ essentials = json.load(f)
+ mapping = list(essentials["characters"])
+ if self.extra_symbols is not None:
+ mapping += self.extra_symbols
+ inverse_mapping = {v: k for k, v in enumerate(mapping)}
+ input_shape = essentials["input_shape"]
+ return mapping, inverse_mapping, input_shape
+
def _to_lower(self) -> None:
"""Converts mapping to lowercase letters only."""