path: root/text_recognizer/networks/transformer/embeddings/fourier.py
"""Fourier positional embedding."""
import numpy as np
import torch
from torch import nn
from torch import Tensor


class PositionalEncoding(nn.Module):
    """Encodes a sense of distance or time for transformer networks."""

    def __init__(
        self, hidden_dim: int, dropout_rate: float, max_len: int = 1000
    ) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_rate)
        pe = self.make_pe(hidden_dim, max_len)
        # Register as a buffer: it moves with the module across devices and is
        # saved in the state dict, but is not a trainable parameter.
        self.register_buffer("pe", pe)

    @staticmethod
    def make_pe(hidden_dim: int, max_len: int) -> Tensor:
        """Returns a [max_len, 1, hidden_dim] sinusoidal positional encoding.

        PE(pos, 2i) = sin(pos / 10000^(2i / hidden_dim))
        PE(pos, 2i + 1) = cos(pos / 10000^(2i / hidden_dim))

        Assumes hidden_dim is even.
        """
        pe = torch.zeros(max_len, hidden_dim)
        # Positions 0 .. max_len - 1 as a column vector.
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Frequencies 1 / 10000^(2i / hidden_dim), computed in log space for
        # numerical stability.
        div_term = torch.exp(
            torch.arange(0, hidden_dim, 2).float() * (-np.log(10000.0) / hidden_dim)
        )

        # Even feature indices get sine, odd indices cosine.
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Add a singleton batch dimension for sequence-first [T, B, D] inputs.
        pe = pe.unsqueeze(1)
        return pe

    def forward(self, x: Tensor) -> Tensor:
        """Adds the positional encoding to the input tensor."""
        # x: [T, B, D] (sequence length, batch size, hidden dimension).
        if x.shape[2] != self.pe.shape[2]:
            raise ValueError("x does not match pe in the hidden (3rd) dimension.")
        x = x + self.pe[: x.shape[0]]
        return self.dropout(x)
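
A minimal usage sketch, not part of the file: it assumes the package is importable under the path above, and the dimensions below (sequence length 50, batch 8, hidden_dim 256) are purely illustrative. Note that the input is sequence-first, [T, B, D], and hidden_dim must match the last dimension of x.

import torch

from text_recognizer.networks.transformer.embeddings.fourier import (
    PositionalEncoding,
)

pos_enc = PositionalEncoding(hidden_dim=256, dropout_rate=0.1, max_len=1000)
x = torch.randn(50, 8, 256)  # [T, B, D]: 50 steps, batch of 8, 256 features
out = pos_enc(x)             # same shape, with position information added
assert out.shape == (50, 8, 256)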