"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence."""
from einops import repeat
import numpy as np
import torch
from torch import nn
from torch import Tensor


class PositionalEncoding(nn.Module):
"""Encodes a sense of distance or time for transformer networks."""
def __init__(
self, hidden_dim: int, dropout_rate: float, max_len: int = 1000
) -> None:
super().__init__()
self.dropout = nn.Dropout(p=dropout_rate)
pe = self.make_pe(hidden_dim, max_len)
self.register_buffer("pe", pe)
@staticmethod
def make_pe(hidden_dim: int, max_len: int) -> Tensor:
"""Returns positional encoding."""
pe = torch.zeros(max_len, hidden_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(
            torch.arange(0, hidden_dim, 2).float() * (-np.log(10000.0) / hidden_dim)
        )  # equals 1 / 10000^(2i / hidden_dim) for i = 0, 1, ..., hidden_dim // 2 - 1
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(1)  # [max_len, 1, hidden_dim]
        return pe

    def forward(self, x: Tensor) -> Tensor:
        """Encodes the tensor with a positional embedding."""
        # x has shape [T, B, D] (sequence-first)
        if x.shape[2] != self.pe.shape[2]:
            raise ValueError(
                f"Hidden dim of x ({x.shape[2]}) does not match the positional encoding ({self.pe.shape[2]})."
            )
x = x + self.pe[: x.shape[0]]
return self.dropout(x)
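

# Usage sketch (the hidden_dim, dropout_rate, and shapes below are illustrative values):
#   pos_enc = PositionalEncoding(hidden_dim=256, dropout_rate=0.1)
#   x = torch.zeros(50, 8, 256)   # [T, B, D], sequence-first
#   x = pos_enc(x)                # same shape, with sinusoidal offsets added and dropout applied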


class PositionalEncoding2D(nn.Module):
"""Positional encodings for feature maps."""
def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int = 2048) -> None:
super().__init__()
if hidden_dim % 2 != 0:
raise ValueError(f"Embedding depth {hidden_dim} is not even!")
self.hidden_dim = hidden_dim
pe = self.make_pe(hidden_dim, max_h, max_w)
self.register_buffer("pe", pe)
@staticmethod
def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor:
"""Returns 2d postional encoding."""
pe_h = PositionalEncoding.make_pe(
hidden_dim // 2, max_len=max_h
) # [H, 1, D // 2]
        pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w)  # [D // 2, H, W]
pe_w = PositionalEncoding.make_pe(
hidden_dim // 2, max_len=max_w
) # [W, 1, D // 2]
        pe_w = repeat(pe_w, "w h d -> d (h tile) w", tile=max_h)  # [D // 2, H, W]
pe = torch.cat([pe_h, pe_w], dim=0) # [D, H, W]
        return pe

    def forward(self, x: Tensor) -> Tensor:
        """Adds the 2D positional encoding to the input tensor."""
        # Assumes x has shape [B, D, H, W]
        if x.shape[1] != self.pe.shape[0]:
            raise ValueError(
                f"Hidden dim of x ({x.shape[1]}) does not match the positional encoding ({self.pe.shape[0]})."
            )
        x = x + self.pe[:, : x.shape[2], : x.shape[3]]  # out-of-place add, so the caller's tensor is not mutated
return x
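

# Usage sketch (the hidden_dim and feature-map shape below are illustrative values):
#   pos_enc_2d = PositionalEncoding2D(hidden_dim=256, max_h=32, max_w=128)
#   x = torch.zeros(8, 256, 32, 128)   # [B, D, H, W]
#   x = pos_enc_2d(x)                  # same shape, with the 2D encoding added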


def target_padding_mask(trg: Tensor, pad_index: int) -> Tensor:
    """Returns a combined padding and causal (subsequent) mask of shape [B, 1, T, T]."""
    trg_pad_mask = (trg != pad_index)[:, None, None]  # [B, 1, 1, T]
    trg_len = trg.shape[1]
    trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).bool()  # [T, T]
    trg_mask = trg_pad_mask & trg_sub_mask  # broadcasts to [B, 1, T, T]
return trg_mask
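

# A minimal smoke-test sketch, not part of the original module; the shapes and the
# pad index of 0 below are illustrative values, chosen only to exercise the code.
if __name__ == "__main__":
    seq = torch.zeros(50, 8, 256)  # [T, B, D]
    print(PositionalEncoding(hidden_dim=256, dropout_rate=0.1)(seq).shape)

    feature_map = torch.zeros(8, 256, 32, 128)  # [B, D, H, W]
    print(PositionalEncoding2D(hidden_dim=256, max_h=32, max_w=128)(feature_map).shape)

    targets = torch.tensor([[1, 4, 5, 2, 0, 0]])  # [B, T], padded with 0
    print(target_padding_mask(targets, pad_index=0).shape)  # [B, 1, T, T]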