summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/positional_encoding.py
blob: 1ba5537b021ec8c13f5e128cc50f9974545a5f63 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence."""
import numpy as np
import torch
from torch import nn
from torch import Tensor


class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to a sequence of features.

    A table of shape ``(1, max_len, hidden_dim)`` is precomputed once and
    stored as the buffer ``pe``. Even feature columns hold
    ``sin(position * freq)`` and odd feature columns hold
    ``cos(position * freq)``, where the frequencies decay geometrically
    from 1 towards 1/10000 across the feature dimension.

    Args:
        hidden_dim: Size of the feature dimension of the inputs. Both even
            and odd sizes are supported.
        dropout_rate: Probability passed to the ``nn.Dropout`` applied after
            the encoding is added.
        max_len: Number of positions in the precomputed table; the longest
            sequence ``forward`` accepts.
    """

    def __init__(
        self, hidden_dim: int, dropout_rate: float, max_len: int = 1000
    ) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_rate)
        self.max_len = max_len

        pe = torch.zeros(max_len, hidden_dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim)
        )
        # (max_len, ceil(hidden_dim / 2)) matrix of position * frequency.
        angles = position * div_term

        pe[:, 0::2] = torch.sin(angles)
        # With an odd hidden_dim there is one fewer odd column than even
        # columns, so only the first hidden_dim // 2 frequencies are used for
        # the cosine half. For even hidden_dim the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : hidden_dim // 2])
        # Leading singleton axis lets the table broadcast over the batch.
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x: Tensor) -> Tensor:
        """Encodes the tensor with a positional embedding.

        Args:
            x: Input features; axis 1 is treated as the sequence axis and the
                last axis must match ``hidden_dim`` (i.e. a
                ``(batch, seq_len, hidden_dim)`` tensor).

        Returns:
            ``x`` plus the first ``seq_len`` rows of the encoding table, with
            dropout applied.

        Raises:
            RuntimeError: If the sequence is longer than the precomputed
                table.
        """
        seq_len = x.shape[1]
        table_len = self.pe.shape[1]
        # Fail with a clear message instead of an opaque broadcasting error
        # (or, when the table has a single row, a silent broadcast of
        # position 0 over every time step).
        if seq_len > table_len:
            raise RuntimeError(
                f"Sequence length {seq_len} exceeds the maximum supported "
                f"length {table_len} of the positional encoding."
            )
        x = x + self.pe[:, :seq_len]
        return self.dropout(x)