summaryrefslogtreecommitdiff
path: root/text_recognizer/data/emnist_mapping.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-03 18:18:48 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-03 18:18:48 +0200
commitbd4bd443f339e95007bfdabf3e060db720f4d4b9 (patch)
treee55cb3744904f7c2a0348b100c7e92a65e538a16 /text_recognizer/data/emnist_mapping.py
parent75801019981492eedf9280cb352eea3d8e99b65f (diff)
Training working, multiple bug fixes
Diffstat (limited to 'text_recognizer/data/emnist_mapping.py')
-rw-r--r--text_recognizer/data/emnist_mapping.py37
1 files changed, 37 insertions, 0 deletions
diff --git a/text_recognizer/data/emnist_mapping.py b/text_recognizer/data/emnist_mapping.py
new file mode 100644
index 0000000..6c4c43b
--- /dev/null
+++ b/text_recognizer/data/emnist_mapping.py
@@ -0,0 +1,37 @@
+"""Emnist mapping."""
+from typing import List, Optional, Union, Set
+
+from torch import Tensor
+
+from text_recognizer.data.base_mapping import AbstractMapping
+from text_recognizer.data.emnist import emnist_mapping
+
+
+class EmnistMapping(AbstractMapping):
+ def __init__(self, extra_symbols: Optional[Set[str]] = None) -> None:
+ self.extra_symbols = set(extra_symbols) if extra_symbols is not None else None
+ self.mapping, self.inverse_mapping, self.input_size = emnist_mapping(
+ self.extra_symbols
+ )
+ super().__init__(self.input_size, self.mapping, self.inverse_mapping)
+
+ def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
+
+ def get_token(self, index: Union[int, Tensor]) -> str:
+ if (index := int(index)) in self.mapping:
+ return self.mapping[index]
+ raise KeyError(f"Index ({index}) not in mapping.")
+
+ def get_index(self, token: str) -> Tensor:
+ if token in self.inverse_mapping:
+ return Tensor(self.inverse_mapping[token])
+ raise KeyError(f"Token ({token}) not found in inverse mapping.")
+
+ def get_text(self, indices: Union[List[int], Tensor]) -> str:
+ if isinstance(indices, Tensor):
+ indices = indices.tolist()
+ return "".join([self.mapping[index] for index in indices])
+
+ def get_indices(self, text: str) -> Tensor:
+ return Tensor([self.inverse_mapping[token] for token in text])