diff options
Diffstat (limited to 'text_recognizer/data/word_piece_mapping.py')
-rw-r--r-- | text_recognizer/data/word_piece_mapping.py | 10 |
1 files changed, 3 insertions, 7 deletions
diff --git a/text_recognizer/data/word_piece_mapping.py b/text_recognizer/data/word_piece_mapping.py
index 59488c3..2f650cd 100644
--- a/text_recognizer/data/word_piece_mapping.py
+++ b/text_recognizer/data/word_piece_mapping.py
@@ -75,7 +75,7 @@ class WordPieceMapping(EmnistMapping):
     def get_text(self, indices: Union[List[int], Tensor]) -> str:
         if isinstance(indices, Tensor):
             indices = indices.tolist()
-        return self.wordpiece_processor.to_text(indices).replace(" ", "▁")
+        return self.wordpiece_processor.to_text(indices)
 
     def get_indices(self, text: str) -> Tensor:
         return self.wordpiece_processor.to_index(text)
@@ -85,9 +85,5 @@ class WordPieceMapping(EmnistMapping):
         text = text.lower().replace(" ", "▁")
         return torch.LongTensor(self.wordpiece_processor.to_index(text))
 
-    def __getitem__(self, x: Union[str, int, List[int], Tensor]) -> Union[str, Tensor]:
-        if isinstance(x, int):
-            x = [x]
-        if isinstance(x, str):
-            return self.get_indices(x)
-        return self.get_text(x)
+    def __getitem__(self, x: Union[int, Tensor]) -> str:
+        return self.get_token(x)