blob: 82e4d548d08573607a7ab16deaa550e1a0e61904 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
|
"""Pad targets to equal length."""
import torch
from torch import Tensor
import torch.functional as F
class Pad:
    """Pad or truncate a target sequence to a fixed length."""

    def __init__(self, max_len: int, pad_index: int) -> None:
        """Store the target length and the fill value.

        Args:
            max_len: Length every returned sequence is brought to.
            pad_index: Value appended to sequences shorter than ``max_len``.
        """
        self.max_len = max_len
        self.pad_index = pad_index

    def __call__(self, y: Tensor) -> Tensor:
        """Right-pad ``y`` with ``pad_index`` up to ``max_len``, then truncate.

        Args:
            y: 1-D tensor of target values (assumed 1-D; padding is
                concatenated along the only dimension).

        Returns:
            ``y`` padded with ``pad_index`` if it was shorter than ``max_len``,
            and truncated to its first ``max_len`` elements otherwise. The
            padding shares ``y``'s dtype and device.
        """
        pad_len = self.max_len - y.shape[-1]
        if pad_len > 0:
            # new_full inherits y's dtype and device, so concatenation also
            # works for CUDA tensors (a plain LongTensor would live on CPU).
            padding = y.new_full((pad_len,), self.pad_index)
            y = torch.cat((y, padding))
        return y[: self.max_len]
|