summaryrefslogtreecommitdiff
path: root/text_recognizer/data/mappings.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/mappings.py')
-rw-r--r--text_recognizer/data/mappings.py10
1 files changed, 9 insertions, 1 deletions
diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py
index 190febe..0d778b2 100644
--- a/text_recognizer/data/mappings.py
+++ b/text_recognizer/data/mappings.py
@@ -125,6 +125,9 @@ class WordPieceMapping(EmnistMapping):
special_tokens,
)
+ def __len__(self) -> int:
+ return len(self.wordpiece_processor.tokens)
+
def get_token(self, index: Union[int, Tensor]) -> str:
if (index := int(index)) <= self.wordpiece_processor.num_tokens:
return self.wordpiece_processor.tokens[index]
@@ -132,7 +135,7 @@ class WordPieceMapping(EmnistMapping):
def get_index(self, token: str) -> Tensor:
if token in self.wordpiece_processor.tokens:
- return torch.LongTensor(self.wordpiece_processor.tokens_to_index[token])
+ return torch.LongTensor([self.wordpiece_processor.tokens_to_index[token]])
raise KeyError(f"Token ({token}) not found in inverse mapping.")
def get_text(self, indices: Union[List[int], Tensor]) -> str:
@@ -147,3 +150,8 @@ class WordPieceMapping(EmnistMapping):
text = "".join([self.mapping[i] for i in x])
text = text.lower().replace(" ", "▁")
return torch.LongTensor(self.wordpiece_processor.to_index(text))
+
+ def __getitem__(self, x: Union[str, int, Tensor]) -> Union[str, Tensor]:
+ if isinstance(x, str):
+ return self.get_index(x)
+ return self.get_token(x)