summary refs log tree commit diff
path: root/text_recognizer/data/transforms.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/transforms.py')
-rw-r--r-- text_recognizer/data/transforms.py 13
1 files changed, 11 insertions, 2 deletions
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index d0f1f35..66531a5 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -2,6 +2,7 @@
from pathlib import Path
from typing import Optional, Union, Sequence
+import torch
from torch import Tensor
from text_recognizer.data.mappings import WordPieceMapping
@@ -20,7 +21,7 @@ class WordPiece:
prepend_wordsep: bool = False,
special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
extra_symbols: Optional[Sequence[str]] = ("\n",),
- max_len: int = 192,
+ max_len: int = 451,
) -> None:
self.mapping = WordPieceMapping(
num_features,
@@ -35,4 +36,12 @@ class WordPiece:
self.max_len = max_len
def __call__(self, x: Tensor) -> Tensor:
- return self.mapping.emnist_to_wordpiece_indices(x)[: self.max_len]
+ y = self.mapping.emnist_to_wordpiece_indices(x)
+ if len(y) < self.max_len:
+ pad_len = self.max_len - len(y)
+ y = torch.cat(
+ (y, torch.LongTensor([self.mapping.get_index("<p>")] * pad_len))
+ )
+ else:
+ y = y[: self.max_len]
+ return y