path: root/text_recognizer/networks/perceiver/attention.py
"""Attention module."""
from typing import Optional

from einops import rearrange
import torch
from torch import einsum, nn, Tensor


class Attention(nn.Module):
    """Multi-head scaled dot-product attention with optional cross-attention context."""

    def __init__(
        self,
        query_dim: int,
        context_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
    ) -> None:
        super().__init__()
        inner_dim = heads * dim_head
        # Fall back to self-attention when no separate context dimension is given.
        context_dim = context_dim if context_dim is not None else query_dim
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, 2 * inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim, bias=False)

    def forward(self, x: Tensor, context: Optional[Tensor] = None) -> Tensor:
        """Attend from x to context, or to x itself when no context is given."""
        h = self.heads
        q = self.to_q(x)
        context = context if context is not None else x
        k, v = self.to_kv(context).chunk(2, dim=-1)

        # Split heads and fold them into the batch dimension: (b, n, h*d) -> (b*h, n, d).
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
        # Scaled dot-product similarity between queries and keys.
        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale

        attn = sim.softmax(dim=-1)
        out = einsum("b i j, b j d -> b i d", attn, v)
        # Merge heads back: (b*h, n, d) -> (b, n, h*d).
        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
        return self.to_out(out)
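

# Illustrative usage sketch (not part of the original module): a Perceiver-style
# cross-attention call in which a small latent array attends to a longer input
# sequence. The dimensions and tensor shapes below are assumptions chosen for
# the example only.
if __name__ == "__main__":
    attention = Attention(query_dim=128, context_dim=64, heads=8, dim_head=64)
    latents = torch.randn(2, 32, 128)  # (batch, num_latents, query_dim)
    inputs = torch.randn(2, 256, 64)   # (batch, seq_len, context_dim)
    out = attention(latents, context=inputs)
    print(out.shape)  # torch.Size([2, 32, 128]) -- output keeps the query shape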