diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-03 19:14:16 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-03 19:14:16 +0200 |
commit | 65d5f6c694e73792e40ed693a1381a792da8d277 (patch) | |
tree | c97067914c18b47108188c51294abd40bb1d6481 /text_recognizer/data/word_piece_mapping.py | |
parent | bd4bd443f339e95007bfdabf3e060db720f4d4b9 (diff) |
Fix bugs in converting text in mappings, add missing word_piece arg in datamodule
Diffstat (limited to 'text_recognizer/data/word_piece_mapping.py')
-rw-r--r-- | text_recognizer/data/word_piece_mapping.py | 10 |
1 files changed, 3 insertions, 7 deletions
diff --git a/text_recognizer/data/word_piece_mapping.py b/text_recognizer/data/word_piece_mapping.py index 59488c3..2f650cd 100644 --- a/text_recognizer/data/word_piece_mapping.py +++ b/text_recognizer/data/word_piece_mapping.py @@ -75,7 +75,7 @@ class WordPieceMapping(EmnistMapping): def get_text(self, indices: Union[List[int], Tensor]) -> str: if isinstance(indices, Tensor): indices = indices.tolist() - return self.wordpiece_processor.to_text(indices).replace(" ", "▁") + return self.wordpiece_processor.to_text(indices) def get_indices(self, text: str) -> Tensor: return self.wordpiece_processor.to_index(text) @@ -85,9 +85,5 @@ class WordPieceMapping(EmnistMapping): text = text.lower().replace(" ", "▁") return torch.LongTensor(self.wordpiece_processor.to_index(text)) - def __getitem__(self, x: Union[str, int, List[int], Tensor]) -> Union[str, Tensor]: - if isinstance(x, int): - x = [x] - if isinstance(x, str): - return self.get_indices(x) - return self.get_text(x) + def __getitem__(self, x: Union[int, Tensor]) -> str: + return self.get_token(x) |