blob: 1da45347707da5f717b085f3479c2be2e51e6d6d (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
|
"""Pad targets to equal length."""
import torch
from torch import Tensor
import torch.functional as F
class Pad:
    """Pad or truncate target sequences to a fixed length.

    Args:
        max_len: Length every output sequence has along its last dimension.
        pad_index: Value used to fill positions past the end of a short
            sequence.
    """

    def __init__(self, max_len: int, pad_index: int) -> None:
        self.max_len = max_len
        self.pad_index = pad_index

    def __call__(self, y: Tensor) -> Tensor:
        """Pad ``y`` with ``pad_index`` up to ``max_len``, or truncate it.

        Padding and truncation both act on the last dimension, so inputs of
        shape ``(..., seq_len)`` are handled as well as 1-D sequences. The
        padding tensor is created with ``y``'s dtype and device, so inputs on
        an accelerator or with a non-int64 dtype concatenate cleanly instead
        of failing or being silently promoted.

        Args:
            y: Target tensor whose last dimension is the sequence dimension.

        Returns:
            Tensor with the same leading dimensions, dtype and device as
            ``y`` and a last dimension of length ``max_len``.
        """
        seq_len = y.shape[-1]
        if seq_len < self.max_len:
            # Measure the gap on the same (last) axis the check uses, so
            # multi-dimensional inputs are padded by the right amount.
            padding = torch.full(
                (*y.shape[:-1], self.max_len - seq_len),
                self.pad_index,
                dtype=y.dtype,
                device=y.device,
            )
            y = torch.cat((y, padding), dim=-1)
        # Truncate along the sequence (last) axis, not the leading one.
        return y[..., : self.max_len]
|