summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/positional_encodings/positional_encoding.py
blob: c50afc385c81c117869b712ef8ffbfbb39d88764 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""A positional encoding for the image features, as the transformer has no notion of the order of the sequence."""
from einops import repeat
import numpy as np
import torch
from torch import nn
from torch import Tensor


class PositionalEncoding(nn.Module):
    """Encodes a sense of distance or time for transformer networks.

    Adds fixed sinusoidal position encodings to a sequence-first input
    ([T, B, D]) and then applies dropout.
    """

    def __init__(
        self, hidden_dim: int, dropout_rate: float, max_len: int = 1000
    ) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_rate)
        pe = self.make_pe(hidden_dim, max_len)
        # A buffer follows .to()/.cuda() and is saved in the state dict,
        # but is not a trainable parameter.
        self.register_buffer("pe", pe)

    @staticmethod
    def make_pe(hidden_dim: int, max_len: int) -> Tensor:
        """Returns sinusoidal positional encoding of shape [max_len, 1, hidden_dim].

        Feature column 2k holds sin(pos * w_k) and column 2k + 1 holds
        cos(pos * w_k), with w_k = 10000 ** (-2k / hidden_dim). Odd values of
        hidden_dim are supported: the last (even) column then has no cosine
        partner.
        """
        pe = torch.zeros(max_len, hidden_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, hidden_dim, 2).float() * (-np.log(10000.0) / hidden_dim)
        )
        angles = position * div_term  # [max_len, ceil(hidden_dim / 2)]

        pe[:, 0::2] = torch.sin(angles)
        # There are only hidden_dim // 2 odd columns, which is one fewer than
        # the number of frequencies when hidden_dim is odd.
        pe[:, 1::2] = torch.cos(angles[:, : hidden_dim // 2])
        return pe.unsqueeze(1)

    def forward(self, x: Tensor) -> Tensor:
        """Encodes the tensor with a positional embedding.

        Args:
            x: Input of shape [T, B, D]. D must equal hidden_dim, and T should
                not exceed max_len.

        Returns:
            Tensor of shape [T, B, D]: x plus the first T encodings, after dropout.

        Raises:
            ValueError: If the last dimension of x does not match hidden_dim.
        """
        # [T, B, D]
        if x.shape[2] != self.pe.shape[2]:
            raise ValueError(
                f"x shape does not match pe in the 3rd dim: "
                f"{x.shape[2]} != {self.pe.shape[2]}."
            )
        # pe[:T] has shape [T, 1, D] and broadcasts over the batch dim.
        x = x + self.pe[: x.shape[0]]
        return self.dropout(x)


class PositionalEncoding2D(nn.Module):
    """Positional encodings for feature maps.

    The first hidden_dim // 2 channels vary only with the row (height) index.
    The remaining hidden_dim // 2 channels vary only with the column (width)
    index. Both halves are built with PositionalEncoding.make_pe.
    """

    def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int = 2048) -> None:
        super().__init__()
        if hidden_dim % 2 != 0:
            raise ValueError(f"Embedding depth {hidden_dim} is not even!")
        self.hidden_dim = hidden_dim
        # NOTE: the buffer holds hidden_dim * max_h * max_w values, which is large
        # with the default 2048 x 2048 limits.
        pe = self.make_pe(hidden_dim, max_h, max_w)
        self.register_buffer("pe", pe)

    @staticmethod
    def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor:
        """Returns 2d positional encoding of shape [hidden_dim, max_h, max_w].

        hidden_dim // 2 is forwarded to PositionalEncoding.make_pe as the
        per-axis depth, so that function must support that size.
        """
        pe_h = PositionalEncoding.make_pe(
            hidden_dim // 2, max_len=max_h
        )  # [H, 1, D // 2]
        # Tile the row encoding across every column: [D // 2, H, W].
        pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w)

        pe_w = PositionalEncoding.make_pe(
            hidden_dim // 2, max_len=max_w
        )  # [W, 1, D // 2]
        # Tile the column encoding across every row: [D // 2, H, W].
        pe_w = repeat(pe_w, "w h d -> d (h tile) w", tile=max_h)

        pe = torch.cat([pe_h, pe_w], dim=0)  # [D, H, W]
        return pe

    def forward(self, x: Tensor) -> Tensor:
        """Adds 2D positional encoding to input tensor.

        Args:
            x: Feature map, assumed to have shape [B, D, H, W] with
                D == hidden_dim, H <= max_h and W <= max_w.

        Returns:
            A new tensor equal to x plus the encoding cropped to (H, W), with
            the same dtype as x. The input tensor is not modified.

        Raises:
            ValueError: If x.shape[1] does not equal hidden_dim.
        """
        if x.shape[1] != self.pe.shape[0]:
            raise ValueError("Hidden dimensions does not match.")
        # The add is out-of-place so the caller's tensor is never mutated. An
        # in-place add would also fail for leaf tensors that require grad.
        # Casting pe to x.dtype keeps the output dtype equal to the input's.
        pe = self.pe[:, : x.shape[2], : x.shape[3]].to(dtype=x.dtype)
        return x + pe


def target_padding_mask(trg: Tensor, pad_index: int) -> Tensor:
    """Returns a target mask that combines padding and causal masking.

    Args:
        trg: Token indices; dim 1 is treated as the sequence length
            (presumably [B, T] - TODO confirm against callers).
        pad_index: Token value that marks padding positions.

    Returns:
        Boolean mask formed by broadcasting the padding mask (trg with singleton
        dims inserted at 1 and 2) against a [T, T] lower-triangular mask. For a
        2D trg this is [B, 1, T, T]. Entry (i, j) is True when trg[:, j] is not
        padding and j <= i.
    """
    seq_len = trg.shape[1]
    # Non-padding positions, expanded so they broadcast against the [T, T] mask.
    not_padding = (trg != pad_index)[:, None, None]
    # Lower triangle including the diagonal: row i keeps columns j <= i.
    steps = torch.arange(seq_len, device=trg.device)
    causal = steps[:, None] >= steps[None, :]
    return not_padding & causal