diff options
Diffstat (limited to 'text_recognizer/data/emnist_mapping.py')
-rw-r--r-- | text_recognizer/data/emnist_mapping.py | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/text_recognizer/data/emnist_mapping.py b/text_recognizer/data/emnist_mapping.py
index 4406db7..12ac809 100644
--- a/text_recognizer/data/emnist_mapping.py
+++ b/text_recognizer/data/emnist_mapping.py
@@ -9,6 +9,8 @@ from text_recognizer.data.emnist import emnist_mapping
 
 
 class EmnistMapping(AbstractMapping):
+    """Mapping for EMNIST labels."""
+
     def __init__(
         self, extra_symbols: Optional[Set[str]] = None, lower: bool = True
     ) -> None:
@@ -32,22 +34,27 @@ class EmnistMapping(AbstractMapping):
         self.mapping = [c for c in self.mapping if not c.isupper()]
 
     def get_token(self, index: Union[int, Tensor]) -> str:
+        """Returns token for index value."""
         if (index := int(index)) <= len(self.mapping):
             return self.mapping[index]
         raise KeyError(f"Index ({index}) not in mapping.")
 
     def get_index(self, token: str) -> Tensor:
+        """Returns index value of token."""
         if token in self.inverse_mapping:
             return torch.LongTensor([self.inverse_mapping[token]])
         raise KeyError(f"Token ({token}) not found in inverse mapping.")
 
     def get_text(self, indices: Union[List[int], Tensor]) -> str:
+        """Returns the text from a list of indices."""
         if isinstance(indices, Tensor):
             indices = indices.tolist()
         return "".join([self.mapping[index] for index in indices])
 
     def get_indices(self, text: str) -> Tensor:
+        """Returns tensor of indices for a string."""
         return Tensor([self.inverse_mapping[token] for token in text])
 
     def __getitem__(self, x: Union[int, Tensor]) -> str:
+        """Returns text for a list of indices."""
         return self.get_token(x)