summaryrefslogtreecommitdiff
path: root/text_recognizer/data/transforms/pad.py
blob: baf637a713c6d2b33f97e44366171b3ada4c52b2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
"""Pad targets to equal length."""

import torch
from torch import Tensor


class Pad:
    """Pad or truncate a target sequence to a fixed length.

    The sequence dimension is the last dimension of the input tensor, so both
    a single 1-D target and a batch of targets (``(..., seq_len)``) are handled.
    """

    def __init__(self, max_len: int, pad_index: int) -> None:
        """Store the target length and the padding value.

        Args:
            max_len: Size every output has along its last dimension.
            pad_index: Value written into positions past the end of a short
                sequence.
        """
        self.max_len = max_len
        self.pad_index = pad_index

    def __call__(self, y: Tensor) -> Tensor:
        """Pad ``y`` with ``pad_index`` if shorter than ``max_len``, else truncate.

        Args:
            y: Target tensor whose last dimension is the sequence dimension.

        Returns:
            A tensor whose last dimension has size ``max_len`` (assuming
            ``max_len`` is non-negative). If padding is applied, the padding
            uses ``y``'s dtype and device.
        """
        pad_len = self.max_len - y.shape[-1]
        if pad_len > 0:
            # Build the padding with y's own leading shape, dtype and device:
            # a hard-coded 1-D CPU LongTensor cannot be concatenated with a
            # CUDA target or with a batched (N-D) target.
            padding = torch.full(
                (*y.shape[:-1], pad_len),
                self.pad_index,
                dtype=y.dtype,
                device=y.device,
            )
            y = torch.cat((y, padding), dim=-1)
        # Truncate along the same (last) dimension the length check uses.
        return y[..., : self.max_len]