From 9d8dcd2120dbe96c0e508b006a4b6e10b128f7a7 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Wed, 3 Nov 2021 22:14:08 +0100
Subject: Add pad transform

---
 text_recognizer/data/transforms/pad.py               | 20 ++++++++++++++++++++
 .../conf/datamodule/iam_extended_paragraphs.yaml     |  1 +
 2 files changed, 21 insertions(+)
 create mode 100644 text_recognizer/data/transforms/pad.py

diff --git a/text_recognizer/data/transforms/pad.py b/text_recognizer/data/transforms/pad.py
new file mode 100644
index 0000000..82e4d54
--- /dev/null
+++ b/text_recognizer/data/transforms/pad.py
@@ -0,0 +1,20 @@
+"""Pad targets to equal length."""
+
+import torch
+from torch import Tensor
+import torch.functional as F
+
+
+class Pad:
+    """Pad target sequence."""
+
+    def __init__(self, max_len: int, pad_index: int) -> None:
+        self.max_len = max_len
+        self.pad_index = pad_index
+
+    def __call__(self, y: Tensor) -> Tensor:
+        """Pads sequences with pad index if shorter than max len."""
+        if y.shape[-1] < self.length:
+            pad_len = self.max_len - len(y)
+            y = torch.cat((y, torch.LongTensor([self.pad_index] * pad_len)))
+        return y[: self.max_len]
diff --git a/training/conf/datamodule/iam_extended_paragraphs.yaml b/training/conf/datamodule/iam_extended_paragraphs.yaml
index fd3ab50..f53e5b6 100644
--- a/training/conf/datamodule/iam_extended_paragraphs.yaml
+++ b/training/conf/datamodule/iam_extended_paragraphs.yaml
@@ -5,3 +5,4 @@ train_fraction: 0.8
 pin_memory: false
 transform: transform/paragraphs.yaml
 test_transform: test_transform/paragraphs.yaml
+target_transform: target_transform/pad.yaml
-- 
cgit v1.2.3-70-g09d2