diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-06-27 20:25:25 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-06-27 20:25:25 +0200 |
commit | 442eac315e4b8be19adab80fb7332d29f68c077c (patch) | |
tree | 1084ff180e7852029918534343ec2a18b6b0485f /text_recognizer/data/transforms.py | |
parent | cafd6b8b10d804b3eee235652cb5218ef4a469b4 (diff) |
Fixed bug in word pieces
Diffstat (limited to 'text_recognizer/data/transforms.py')
-rw-r--r-- | text_recognizer/data/transforms.py | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index d0f1f35..66531a5 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -2,6 +2,7 @@
 from pathlib import Path
 from typing import Optional, Union, Sequence
 
+import torch
 from torch import Tensor
 
 from text_recognizer.data.mappings import WordPieceMapping
@@ -20,7 +21,7 @@ class WordPiece:
         prepend_wordsep: bool = False,
         special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
         extra_symbols: Optional[Sequence[str]] = ("\n",),
-        max_len: int = 192,
+        max_len: int = 451,
     ) -> None:
         self.mapping = WordPieceMapping(
             num_features,
@@ -35,4 +36,12 @@ class WordPiece:
         self.max_len = max_len
 
     def __call__(self, x: Tensor) -> Tensor:
-        return self.mapping.emnist_to_wordpiece_indices(x)[: self.max_len]
+        y = self.mapping.emnist_to_wordpiece_indices(x)
+        if len(y) < self.max_len:
+            pad_len = self.max_len - len(y)
+            y = torch.cat(
+                (y, torch.LongTensor([self.mapping.get_index("<p>")] * pad_len))
+            )
+        else:
+            y = y[: self.max_len]
+        return y