diff options
Diffstat (limited to 'text_recognizer/data/mappings.py')
| -rw-r--r-- | text_recognizer/data/mappings.py | 10 | 
1 files changed, 9 insertions, 1 deletions
diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py index 190febe..0d778b2 100644 --- a/text_recognizer/data/mappings.py +++ b/text_recognizer/data/mappings.py @@ -125,6 +125,9 @@ class WordPieceMapping(EmnistMapping):              special_tokens,          ) +    def __len__(self) -> int: +        return len(self.wordpiece_processor.tokens) +      def get_token(self, index: Union[int, Tensor]) -> str:          if (index := int(index)) <= self.wordpiece_processor.num_tokens:              return self.wordpiece_processor.tokens[index] @@ -132,7 +135,7 @@ class WordPieceMapping(EmnistMapping):      def get_index(self, token: str) -> Tensor:          if token in self.wordpiece_processor.tokens: -            return torch.LongTensor(self.wordpiece_processor.tokens_to_index[token]) +            return torch.LongTensor([self.wordpiece_processor.tokens_to_index[token]])          raise KeyError(f"Token ({token}) not found in inverse mapping.")      def get_text(self, indices: Union[List[int], Tensor]) -> str: @@ -147,3 +150,8 @@ class WordPieceMapping(EmnistMapping):          text = "".join([self.mapping[i] for i in x])          text = text.lower().replace(" ", "▁")          return torch.LongTensor(self.wordpiece_processor.to_index(text)) + +    def __getitem__(self, x: Union[str, int, Tensor]) -> Union[str, Tensor]: +        if isinstance(x, str): +            return self.get_index(x) +        return self.get_token(x)  |