diff options
Diffstat (limited to 'text_recognizer/data/emnist_mapping.py')
-rw-r--r-- | text_recognizer/data/emnist_mapping.py | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/text_recognizer/data/emnist_mapping.py b/text_recognizer/data/emnist_mapping.py new file mode 100644 index 0000000..6c4c43b --- /dev/null +++ b/text_recognizer/data/emnist_mapping.py @@ -0,0 +1,37 @@ +"""Emnist mapping.""" +from typing import List, Optional, Union, Set + +from torch import Tensor + +from text_recognizer.data.base_mapping import AbstractMapping +from text_recognizer.data.emnist import emnist_mapping + + +class EmnistMapping(AbstractMapping): + def __init__(self, extra_symbols: Optional[Set[str]] = None) -> None: + self.extra_symbols = set(extra_symbols) if extra_symbols is not None else None + self.mapping, self.inverse_mapping, self.input_size = emnist_mapping( + self.extra_symbols + ) + super().__init__(self.input_size, self.mapping, self.inverse_mapping) + + def __attrs_post_init__(self) -> None: + """Post init configuration.""" + + def get_token(self, index: Union[int, Tensor]) -> str: + if (index := int(index)) in self.mapping: + return self.mapping[index] + raise KeyError(f"Index ({index}) not in mapping.") + + def get_index(self, token: str) -> Tensor: + if token in self.inverse_mapping: + return Tensor(self.inverse_mapping[token]) + raise KeyError(f"Token ({token}) not found in inverse mapping.") + + def get_text(self, indices: Union[List[int], Tensor]) -> str: + if isinstance(indices, Tensor): + indices = indices.tolist() + return "".join([self.mapping[index] for index in indices]) + + def get_indices(self, text: str) -> Tensor: + return Tensor([self.inverse_mapping[token] for token in text]) |