summaryrefslogtreecommitdiff
path: root/text_recognizer
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-03 22:14:08 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-03 22:14:08 +0100
commit9d8dcd2120dbe96c0e508b006a4b6e10b128f7a7 (patch)
tree6c5ce392fa96581d87652d4835c268441f34db6b /text_recognizer
parent9381895e6f0154b0e9acc9e540266367e8a35843 (diff)
Add pad transform
Diffstat (limited to 'text_recognizer')
-rw-r--r--text_recognizer/data/transforms/pad.py20
1 files changed, 20 insertions, 0 deletions
diff --git a/text_recognizer/data/transforms/pad.py b/text_recognizer/data/transforms/pad.py
new file mode 100644
index 0000000..82e4d54
--- /dev/null
+++ b/text_recognizer/data/transforms/pad.py
@@ -0,0 +1,20 @@
+"""Pad targets to equal length."""
+
+import torch
+from torch import Tensor
+import torch.functional as F
+
+
+class Pad:
+ """Pad target sequence."""
+
+ def __init__(self, max_len: int, pad_index: int) -> None:
+ self.max_len = max_len
+ self.pad_index = pad_index
+
+ def __call__(self, y: Tensor) -> Tensor:
+ """Pads sequences with pad index if shorter than max len."""
+ if y.shape[-1] < self.length:
+ pad_len = self.max_len - len(y)
+ y = torch.cat((y, torch.LongTensor([self.pad_index] * pad_len)))
+ return y[: self.max_len]