summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/attention.py
blob: 85f513e1dc4e86330f97995776d4821599fc7776 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""Implementes the attention module for the transformer."""
from typing import Optional

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor, einsum, nn

from text_recognizer.networks.transformer.embeddings.rotary import (
    RotaryEmbedding,
)


class Attention(nn.Module):
    """Multi-head attention whose query heads share a single key/value head.

    Queries are projected to ``num_heads * dim_head`` features and split into
    heads, while keys and values are each projected to one ``dim_head``-wide
    head that every query head attends over.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        causal: bool = False,
        dim_head: int = 64,
        dropout_rate: float = 0.0,
    ) -> None:
        """Builds the projection layers.

        Args:
            dim: Feature width of the input (and of the returned tensor).
            num_heads: Number of query heads.
            causal: If True, a causal mask is applied to the energies.
            dim_head: Feature width of each head.
            dropout_rate: Dropout probability applied to the attention weights.
        """
        super().__init__()
        self.dim = dim
        # NOTE(review): the energies are scaled by the model width ``dim``
        # rather than the head width ``dim_head``. Confirm this is intended
        # before changing it, since trained weights depend on the value.
        self.scale = self.dim**-0.5
        self.num_heads = num_heads
        self.dim_head = dim_head

        self.causal = causal
        self.dropout_rate = dropout_rate

        # Width of all query heads concatenated together.
        inner_dim = self.num_heads * self.dim_head

        # Keys and values use one shared head, hence only ``dim_head`` outputs.
        self.to_q = nn.Linear(self.dim, inner_dim, bias=False)
        self.to_k = nn.Linear(self.dim, self.dim_head, bias=False)
        self.to_v = nn.Linear(self.dim, self.dim_head, bias=False)

        self.dropout = nn.Dropout(p=self.dropout_rate)

        # Projects the concatenated heads back to the model width.
        self.fc = nn.Linear(inner_dim, self.dim)

    def forward(
        self,
        x: Tensor,
        context: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        rotary_embedding: Optional[RotaryEmbedding] = None,
    ) -> Tensor:
        """Attends from ``x`` over ``context`` (or over ``x`` itself).

        Args:
            x: Query input with ``dim`` features in its last axis.
            context: Optional source for the keys and values; when omitted the
                keys and values are computed from ``x``.
            mask: Optional mask handed to ``apply_input_mask`` (and to
                ``apply_causal_mask`` when ``causal`` is set).
            rotary_embedding: Optional embedding whose ``rotate`` method is
                applied to the queries, keys and values.

        Returns:
            The attended features projected back to ``dim`` features.
        """
        batch_size = x.shape[0]
        device = x.device

        q = rearrange(self.to_q(x), "b n (h d) -> b h n d", h=self.num_heads)

        # Keys and values share one source: the context if given, else x.
        source = x if context is None else context
        k = self.to_k(source)
        v = self.to_v(source)

        if rotary_embedding is not None:
            q, k, v = (rotary_embedding.rotate(t) for t in (q, k, v))

        # Every query head is compared against the single shared key head.
        energy = einsum("b h i d, b j d -> b h i j", q, k) * self.scale

        # Large finite negative, so masked entries vanish under the softmax.
        fill_value = -torch.finfo(energy.dtype).max
        energy = apply_input_mask(
            batch_size, k, energy, mask, fill_value, device
        )
        if self.causal:
            energy = apply_causal_mask(energy, mask, fill_value, device)

        weights = self.dropout(F.softmax(energy, dim=-1))
        attended = einsum("b h i j, b j d -> b h i d", weights, v)
        merged = rearrange(attended, "b h n d -> b n (h d)")
        return self.fc(merged)


def apply_input_mask(
    b: int,
    k: Tensor,
    energy: Tensor,
    mask: Optional[Tensor],
    mask_value: float,
    device: torch.device,
) -> Tensor:
    """Fills the energy rows of masked-out queries with ``mask_value``.

    The mask marks which query positions are valid. Every key is treated as
    valid, so a masked query has its whole row of energies overwritten.

    Args:
        b: Batch size. Unused; kept so existing callers keep working.
        k: Key tensor. Unused; kept so existing callers keep working.
        energy: Attention energies of shape (batch, heads, queries, keys).
            Modified in place when a mask is given.
        mask: Optional (batch, queries) mask where non-zero/True marks a
            position to keep. Any dtype is accepted; it is cast to bool.
        mask_value: Value written into the masked positions.
        device: Unused; kept so existing callers keep working.

    Returns:
        ``energy`` itself (masked in place), or unchanged if ``mask`` is None.

    Raises:
        ValueError: If ``mask`` is not two-dimensional.
    """
    if mask is None:
        return energy

    if mask.dim() != 2:
        raise ValueError(
            f"Expected a (batch, queries) mask, got shape {tuple(mask.shape)}."
        )

    # Cast first: `~` on an integer mask is a bitwise not, and masked_fill_
    # only accepts boolean masks.
    keep = mask.to(torch.bool)[:, None, :, None]

    # The (batch, 1, queries, 1) mask broadcasts over heads and keys, so no
    # explicit all-ones key mask is needed.
    return energy.masked_fill_(~keep, mask_value)


def apply_causal_mask(
    energy: Tensor, mask: Tensor, mask_value: Tensor, device: str
) -> Tensor:
    """Hides future key positions from each query, modifying ``energy`` in place.

    The queries are aligned with the last ``i`` of the ``j`` keys, so query
    ``q`` may attend to keys ``0 .. q + (j - i)`` and every later key is
    overwritten with ``mask_value``.

    Args:
        energy: Attention energies whose last two axes are (queries, keys).
        mask: Not read by this function; accepted for signature compatibility.
        mask_value: Fill value written into the hidden positions.
        device: Device on which the position indices are created.

    Returns:
        The same ``energy`` tensor after the in-place fill.
    """
    num_queries, num_keys = energy.shape[-2:]
    # Number of leading keys that every query is allowed to see.
    offset = num_keys - num_queries

    query_pos = torch.arange(num_queries, device=device).view(1, 1, -1, 1)
    key_pos = torch.arange(num_keys, device=device).view(1, 1, 1, -1)

    # A key lies in the future when its offset-shifted index exceeds the
    # query's index.
    future = (key_pos - offset) > query_pos
    energy.masked_fill_(future, mask_value)
    return energy