diff options
Diffstat (limited to 'text_recognizer/data/transforms.py')
-rw-r--r-- | text_recognizer/data/transforms.py | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py index d0f1f35..66531a5 100644 --- a/text_recognizer/data/transforms.py +++ b/text_recognizer/data/transforms.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import Optional, Union, Sequence +import torch from torch import Tensor from text_recognizer.data.mappings import WordPieceMapping @@ -20,7 +21,7 @@ class WordPiece: prepend_wordsep: bool = False, special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"), extra_symbols: Optional[Sequence[str]] = ("\n",), - max_len: int = 192, + max_len: int = 451, ) -> None: self.mapping = WordPieceMapping( num_features, @@ -35,4 +36,12 @@ class WordPiece: self.max_len = max_len def __call__(self, x: Tensor) -> Tensor: - return self.mapping.emnist_to_wordpiece_indices(x)[: self.max_len] + y = self.mapping.emnist_to_wordpiece_indices(x) + if len(y) < self.max_len: + pad_len = self.max_len - len(y) + y = torch.cat( + (y, torch.LongTensor([self.mapping.get_index("<p>")] * pad_len)) + ) + else: + y = y[: self.max_len] + return y |