summaryrefslogtreecommitdiff
path: root/text_recognizer/data/word_piece_mapping.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-03 19:14:16 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-03 19:14:16 +0200
commit65d5f6c694e73792e40ed693a1381a792da8d277 (patch)
treec97067914c18b47108188c51294abd40bb1d6481 /text_recognizer/data/word_piece_mapping.py
parentbd4bd443f339e95007bfdabf3e060db720f4d4b9 (diff)
Fix bugs in converting text in mappings, add missing word_piece arg in datamodule
Diffstat (limited to 'text_recognizer/data/word_piece_mapping.py')
-rw-r--r--text_recognizer/data/word_piece_mapping.py10
1 files changed, 3 insertions, 7 deletions
diff --git a/text_recognizer/data/word_piece_mapping.py b/text_recognizer/data/word_piece_mapping.py
index 59488c3..2f650cd 100644
--- a/text_recognizer/data/word_piece_mapping.py
+++ b/text_recognizer/data/word_piece_mapping.py
@@ -75,7 +75,7 @@ class WordPieceMapping(EmnistMapping):
def get_text(self, indices: Union[List[int], Tensor]) -> str:
if isinstance(indices, Tensor):
indices = indices.tolist()
- return self.wordpiece_processor.to_text(indices).replace(" ", "▁")
+ return self.wordpiece_processor.to_text(indices)
def get_indices(self, text: str) -> Tensor:
return self.wordpiece_processor.to_index(text)
@@ -85,9 +85,5 @@ class WordPieceMapping(EmnistMapping):
text = text.lower().replace(" ", "▁")
return torch.LongTensor(self.wordpiece_processor.to_index(text))
- def __getitem__(self, x: Union[str, int, List[int], Tensor]) -> Union[str, Tensor]:
- if isinstance(x, int):
- x = [x]
- if isinstance(x, str):
- return self.get_indices(x)
- return self.get_text(x)
+ def __getitem__(self, x: Union[int, Tensor]) -> str:
+ return self.get_token(x)