blob: baf637a713c6d2b33f97e44366171b3ada4c52b2 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
|
"""Pad targets to equal length."""
import torch
from torch import Tensor
class Pad:
    """Pad (or truncate) a target sequence to a fixed length.

    Args:
        max_len: Size the last (sequence) dimension of every output has.
        pad_index: Value written into the padded positions.
    """

    def __init__(self, max_len: int, pad_index: int) -> None:
        self.max_len = max_len
        self.pad_index = pad_index

    def __call__(self, y: Tensor) -> Tensor:
        """Pad ``y`` with ``pad_index`` up to ``max_len``, or truncate it.

        Both padding and truncation act on the last dimension, so a 1-D
        sequence and a batched ``(..., seq_len)`` tensor are handled the
        same way. The padding tensor is created with ``y``'s dtype and
        device, so concatenation works for non-CPU tensors too.

        Args:
            y: Target tensor whose last dimension is the sequence axis.

        Returns:
            Tensor identical to ``y`` except that its last dimension has
            size ``max_len`` (padded with ``pad_index`` or cut off).
        """
        seq_len = y.shape[-1]
        if seq_len < self.max_len:
            # Measure the pad on the same axis the condition checks;
            # ``len(y)`` would read the first axis on multi-dim inputs.
            pad = torch.full(
                (*y.shape[:-1], self.max_len - seq_len),
                self.pad_index,
                dtype=y.dtype,
                device=y.device,
            )
            y = torch.cat((y, pad), dim=-1)
        # Truncate along the sequence axis, not the leading one.
        return y[..., : self.max_len]
|