diff options
Diffstat (limited to 'text_recognizer/data/transforms.py')
-rw-r--r-- | text_recognizer/data/transforms.py | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py index f53df64..8d1bedd 100644 --- a/text_recognizer/data/transforms.py +++ b/text_recognizer/data/transforms.py @@ -4,7 +4,7 @@ from typing import Optional, Union, Sequence from torch import Tensor -from text_recognizer.datasets.mappings import WordPieceMapping +from text_recognizer.data.mappings import WordPieceMapping class WordPiece: @@ -12,14 +12,15 @@ class WordPiece: def __init__( self, - num_features: int, - tokens: str, - lexicon: str, + num_features: int = 1000, + tokens: str = "iamdb_1kwp_tokens_1000.txt" , + lexicon: str = "iamdb_1kwp_lex_1000.txt", data_dir: Optional[Union[str, Path]] = None, use_words: bool = False, prepend_wordsep: bool = False, special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"), - extra_symbols: Optional[Sequence[str]] = None, + extra_symbols: Optional[Sequence[str]] = ("\n",), + max_len: int = 192, ) -> None: self.mapping = WordPieceMapping( num_features, @@ -31,6 +32,7 @@ class WordPiece: special_tokens, extra_symbols, ) + self.max_len = max_len def __call__(self, x: Tensor) -> Tensor: - return self.mapping.emnist_to_wordpiece_indices(x) + return self.mapping.emnist_to_wordpiece_indices(x)[:self.max_len] |