summaryrefslogtreecommitdiff
path: root/text_recognizer/data/word_piece_mapping.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/word_piece_mapping.py')
-rw-r--r--text_recognizer/data/word_piece_mapping.py15
1 files changed, 12 insertions, 3 deletions
diff --git a/text_recognizer/data/word_piece_mapping.py b/text_recognizer/data/word_piece_mapping.py
index 2f650cd..dc56942 100644
--- a/text_recognizer/data/word_piece_mapping.py
+++ b/text_recognizer/data/word_piece_mapping.py
@@ -1,9 +1,9 @@
"""Word piece mapping."""
from pathlib import Path
-from typing import List, Optional, Union, Set
+from typing import List, Optional, Set, Union
-import torch
from loguru import logger as log
+import torch
from torch import Tensor
from text_recognizer.data.emnist_mapping import EmnistMapping
@@ -11,6 +11,8 @@ from text_recognizer.data.iam_preprocessor import Preprocessor
class WordPieceMapping(EmnistMapping):
+ """Word piece mapping."""
+
def __init__(
self,
data_dir: Optional[Path] = None,
@@ -20,7 +22,7 @@ class WordPieceMapping(EmnistMapping):
use_words: bool = False,
prepend_wordsep: bool = False,
special_tokens: Set[str] = {"<s>", "<e>", "<p>"},
- extra_symbols: Set[str] = {"\n",},
+ extra_symbols: Set[str] = {"\n"},
) -> None:
super().__init__(extra_symbols=extra_symbols)
self.data_dir = (
@@ -60,30 +62,37 @@ class WordPieceMapping(EmnistMapping):
)
def __len__(self) -> int:
+ """Return number of word pieces."""
return len(self.wordpiece_processor.tokens)
def get_token(self, index: Union[int, Tensor]) -> str:
+ """Returns token for index."""
if (index := int(index)) <= self.wordpiece_processor.num_tokens:
return self.wordpiece_processor.tokens[index]
raise KeyError(f"Index ({index}) not in mapping.")
def get_index(self, token: str) -> Tensor:
+ """Returns index of token."""
if token in self.wordpiece_processor.tokens:
return torch.LongTensor([self.wordpiece_processor.tokens_to_index[token]])
raise KeyError(f"Token ({token}) not found in inverse mapping.")
def get_text(self, indices: Union[List[int], Tensor]) -> str:
+ """Returns text from indices."""
if isinstance(indices, Tensor):
indices = indices.tolist()
return self.wordpiece_processor.to_text(indices)
def get_indices(self, text: str) -> Tensor:
+ """Returns indices of text."""
return self.wordpiece_processor.to_index(text)
def emnist_to_wordpiece_indices(self, x: Tensor) -> Tensor:
+ """Returns word pieces indices from emnist indices."""
text = "".join([self.mapping[i] for i in x])
text = text.lower().replace(" ", "▁")
return torch.LongTensor(self.wordpiece_processor.to_index(text))
def __getitem__(self, x: Union[int, Tensor]) -> str:
+ """Returns token for word piece index."""
return self.get_token(x)