"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence."""
import numpy as np
import torch
from torch import nn
from torch import Tensor
class PositionalEncoding(nn.Module):
"""Encodes a sense of distance or time for transformer networks."""
def __init__(
self, hidden_dim: int, dropout_rate: float, max_len: int = 1000
) -> None:
super().__init__()
self.dropout = nn.Dropout(p=dropout_rate)
self.max_len = max_len
pe = torch.zeros(max_len, hidden_dim)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer("pe", pe)
def forward(self, x: Tensor) -> Tensor:
"""Encodes the tensor with a postional embedding."""
x = x + self.pe[:, : x.shape[1]]
return self.dropout(x)
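

# A minimal usage sketch, assuming the image features arrive as a
# (batch, sequence_length, hidden_dim) tensor, which is the layout the
# forward pass above broadcasts against; the sizes below are illustrative.
if __name__ == "__main__":
    encoder = PositionalEncoding(hidden_dim=256, dropout_rate=0.1)
    features = torch.randn(8, 49, 256)  # e.g. a 7x7 feature map flattened to 49 tokens
    encoded = encoder(features)
    print(encoded.shape)  # torch.Size([8, 49, 256]); the encoding is added element-wise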