diff options
Diffstat (limited to 'text_recognizer/data/mappings.py')
-rw-r--r-- | text_recognizer/data/mappings.py | 10 |
1 files changed, 9 insertions, 1 deletions
diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py index 190febe..0d778b2 100644 --- a/text_recognizer/data/mappings.py +++ b/text_recognizer/data/mappings.py @@ -125,6 +125,9 @@ class WordPieceMapping(EmnistMapping): special_tokens, ) + def __len__(self) -> int: + return len(self.wordpiece_processor.tokens) + def get_token(self, index: Union[int, Tensor]) -> str: if (index := int(index)) <= self.wordpiece_processor.num_tokens: return self.wordpiece_processor.tokens[index] @@ -132,7 +135,7 @@ class WordPieceMapping(EmnistMapping): def get_index(self, token: str) -> Tensor: if token in self.wordpiece_processor.tokens: - return torch.LongTensor(self.wordpiece_processor.tokens_to_index[token]) + return torch.LongTensor([self.wordpiece_processor.tokens_to_index[token]]) raise KeyError(f"Token ({token}) not found in inverse mapping.") def get_text(self, indices: Union[List[int], Tensor]) -> str: @@ -147,3 +150,8 @@ class WordPieceMapping(EmnistMapping): text = "".join([self.mapping[i] for i in x]) text = text.lower().replace(" ", "▁") return torch.LongTensor(self.wordpiece_processor.to_index(text)) + + def __getitem__(self, x: Union[str, int, Tensor]) -> Union[str, Tensor]: + if isinstance(x, str): + return self.get_index(x) + return self.get_token(x) |