diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-03 00:31:00 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-03 00:31:00 +0200 |
commit | 3a21c29e2eff4378c63717f8920ca3ccbfef013c (patch) | |
tree | ba46504d7baa8d4fb5bfd473acf99a7a184b330c /text_recognizer/data/word_piece_mapping.py | |
parent | 75eb34020620584247313926527019471411f6af (diff) |
Lint files
Diffstat (limited to 'text_recognizer/data/word_piece_mapping.py')
-rw-r--r-- | text_recognizer/data/word_piece_mapping.py | 15 |
1 file changed, 12 insertions, 3 deletions
diff --git a/text_recognizer/data/word_piece_mapping.py b/text_recognizer/data/word_piece_mapping.py index 2f650cd..dc56942 100644 --- a/text_recognizer/data/word_piece_mapping.py +++ b/text_recognizer/data/word_piece_mapping.py @@ -1,9 +1,9 @@ """Word piece mapping.""" from pathlib import Path -from typing import List, Optional, Union, Set +from typing import List, Optional, Set, Union -import torch from loguru import logger as log +import torch from torch import Tensor from text_recognizer.data.emnist_mapping import EmnistMapping @@ -11,6 +11,8 @@ from text_recognizer.data.iam_preprocessor import Preprocessor class WordPieceMapping(EmnistMapping): + """Word piece mapping.""" + def __init__( self, data_dir: Optional[Path] = None, @@ -20,7 +22,7 @@ class WordPieceMapping(EmnistMapping): use_words: bool = False, prepend_wordsep: bool = False, special_tokens: Set[str] = {"<s>", "<e>", "<p>"}, - extra_symbols: Set[str] = {"\n",}, + extra_symbols: Set[str] = {"\n"}, ) -> None: super().__init__(extra_symbols=extra_symbols) self.data_dir = ( @@ -60,30 +62,37 @@ class WordPieceMapping(EmnistMapping): ) def __len__(self) -> int: + """Return number of word pieces.""" return len(self.wordpiece_processor.tokens) def get_token(self, index: Union[int, Tensor]) -> str: + """Returns token for index.""" if (index := int(index)) <= self.wordpiece_processor.num_tokens: return self.wordpiece_processor.tokens[index] raise KeyError(f"Index ({index}) not in mapping.") def get_index(self, token: str) -> Tensor: + """Returns index of token.""" if token in self.wordpiece_processor.tokens: return torch.LongTensor([self.wordpiece_processor.tokens_to_index[token]]) raise KeyError(f"Token ({token}) not found in inverse mapping.") def get_text(self, indices: Union[List[int], Tensor]) -> str: + """Returns text from indices.""" if isinstance(indices, Tensor): indices = indices.tolist() return self.wordpiece_processor.to_text(indices) def get_indices(self, text: str) -> Tensor: + """Returns indices of text.""" return self.wordpiece_processor.to_index(text) def emnist_to_wordpiece_indices(self, x: Tensor) -> Tensor: + """Returns word pieces indices from emnist indices.""" text = "".join([self.mapping[i] for i in x]) text = text.lower().replace(" ", "▁") return torch.LongTensor(self.wordpiece_processor.to_index(text)) def __getitem__(self, x: Union[int, Tensor]) -> str: + """Returns token for word piece index.""" return self.get_token(x) |