From 9d8dcd2120dbe96c0e508b006a4b6e10b128f7a7 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 3 Nov 2021 22:14:08 +0100 Subject: Add pad transform --- text_recognizer/data/transforms/pad.py | 20 ++++++++++++++++++++ .../conf/datamodule/iam_extended_paragraphs.yaml | 1 + 2 files changed, 21 insertions(+) create mode 100644 text_recognizer/data/transforms/pad.py diff --git a/text_recognizer/data/transforms/pad.py b/text_recognizer/data/transforms/pad.py new file mode 100644 index 0000000..82e4d54 --- /dev/null +++ b/text_recognizer/data/transforms/pad.py @@ -0,0 +1,20 @@ +"""Pad targets to equal length.""" + +import torch +from torch import Tensor +import torch.functional as F + + +class Pad: + """Pad target sequence.""" + + def __init__(self, max_len: int, pad_index: int) -> None: + self.max_len = max_len + self.pad_index = pad_index + + def __call__(self, y: Tensor) -> Tensor: + """Pads sequences with pad index if shorter than max len.""" + if y.shape[-1] < self.length: + pad_len = self.max_len - len(y) + y = torch.cat((y, torch.LongTensor([self.pad_index] * pad_len))) + return y[: self.max_len] diff --git a/training/conf/datamodule/iam_extended_paragraphs.yaml b/training/conf/datamodule/iam_extended_paragraphs.yaml index fd3ab50..f53e5b6 100644 --- a/training/conf/datamodule/iam_extended_paragraphs.yaml +++ b/training/conf/datamodule/iam_extended_paragraphs.yaml @@ -5,3 +5,4 @@ train_fraction: 0.8 pin_memory: false transform: transform/paragraphs.yaml test_transform: test_transform/paragraphs.yaml +target_transform: target_transform/pad.yaml -- cgit v1.2.3-70-g09d2