summaryrefslogtreecommitdiff
path: root/text_recognizer/data/emnist_mapping.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/emnist_mapping.py')
-rw-r--r--text_recognizer/data/emnist_mapping.py7
1 files changed, 7 insertions, 0 deletions
diff --git a/text_recognizer/data/emnist_mapping.py b/text_recognizer/data/emnist_mapping.py
index 4406db7..12ac809 100644
--- a/text_recognizer/data/emnist_mapping.py
+++ b/text_recognizer/data/emnist_mapping.py
@@ -9,6 +9,8 @@ from text_recognizer.data.emnist import emnist_mapping
class EmnistMapping(AbstractMapping):
+ """Mapping for EMNIST labels."""
+
def __init__(
self, extra_symbols: Optional[Set[str]] = None, lower: bool = True
) -> None:
@@ -32,22 +34,27 @@ class EmnistMapping(AbstractMapping):
self.mapping = [c for c in self.mapping if not c.isupper()]
def get_token(self, index: Union[int, Tensor]) -> str:
+ """Returns token for index value."""
if (index := int(index)) <= len(self.mapping):
return self.mapping[index]
raise KeyError(f"Index ({index}) not in mapping.")
def get_index(self, token: str) -> Tensor:
+ """Returns index value of token."""
if token in self.inverse_mapping:
return torch.LongTensor([self.inverse_mapping[token]])
raise KeyError(f"Token ({token}) not found in inverse mapping.")
def get_text(self, indices: Union[List[int], Tensor]) -> str:
+ """Returns the text from a list of indices."""
if isinstance(indices, Tensor):
indices = indices.tolist()
return "".join([self.mapping[index] for index in indices])
def get_indices(self, text: str) -> Tensor:
+ """Returns tensor of indices for a string."""
return Tensor([self.inverse_mapping[token] for token in text])
def __getitem__(self, x: Union[int, Tensor]) -> str:
+ """Returns text for a list of indices."""
return self.get_token(x)