summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/convnext/attention.py
blob: a373c379916afaee29a5c0ed5d1deed65b1c796f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""Convolution self attention block."""

import torch.nn.functional as F
from einops import rearrange, reduce
from torch import Tensor, einsum, nn

from text_recognizer.networks.convnext.norm import LayerNorm
from text_recognizer.networks.convnext.residual import Residual


def l2norm(t: Tensor) -> Tensor:
    """Rescale ``t`` to unit Euclidean length along its last dimension.

    Args:
        t: Tensor whose last axis holds the vectors to normalise.

    Returns:
        A tensor shaped like ``t``. Vectors whose norm is below ``1e-12`` are
        divided by ``1e-12`` instead, so an all-zero vector stays all-zero.
    """
    # p and eps are the library defaults, written out for the reader.
    unit = F.normalize(t, p=2.0, dim=-1, eps=1e-12)
    return unit


class FeedForward(nn.Module):
    """Pointwise (1x1 convolution) feed-forward stack wrapped in ``Residual``.

    The stack is: norm -> 1x1 conv (expand) -> GELU -> norm -> 1x1 conv
    (project back). Both convolutions are bias-free.
    """

    def __init__(self, dim: int, mult: int = 4) -> None:
        """Assemble the feed-forward layers.

        Args:
            dim: Number of channels entering and leaving the block.
            mult: Expansion factor; the hidden width is ``int(dim * mult)``.
        """
        super().__init__()
        hidden_dim = int(dim * mult)
        # Layers are created in the order they run, which keeps parameter
        # registration and initialisation order unchanged.
        layers = [
            LayerNorm(dim),
            nn.Conv2d(dim, hidden_dim, 1, bias=False),
            nn.GELU(),
            LayerNorm(hidden_dim),
            nn.Conv2d(hidden_dim, dim, 1, bias=False),
        ]
        # NOTE(review): Residual is defined in the project's residual module;
        # presumably it adds the input back to the wrapped output — confirm.
        self.fn = Residual(nn.Sequential(*layers))

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through the residual-wrapped feed-forward stack."""
        out = self.fn(x)
        return out


class Attention(nn.Module):
    """Multi-head self-attention across the spatial positions of a feature map.

    Queries and keys are L2-normalised along the per-head channel axis, so the
    attention logits are cosine similarities multiplied by a fixed ``scale``.
    The block's input is added back onto the projected attention output.
    """

    def __init__(
        self, dim: int, heads: int = 4, dim_head: int = 64, scale: int = 8
    ) -> None:
        """Create the normalisation layer and the 1x1 projections.

        Args:
            dim: Channel count of the input and output feature maps.
            heads: Number of attention heads.
            dim_head: Channels per head; total width is ``heads * dim_head``.
            scale: Multiplier applied to the similarity logits before softmax.
        """
        super().__init__()
        self.heads = heads
        self.scale = scale
        self.norm = LayerNorm(dim)

        # Bias-free 1x1 convolutions: the first emits q, k and v stacked along
        # the channel axis, the second maps the merged heads back to ``dim``.
        qkv_channels = 3 * heads * dim_head
        self.to_qkv = nn.Conv2d(dim, qkv_channels, 1, bias=False)
        self.to_out = nn.Conv2d(heads * dim_head, dim, 1, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        """Attend over every spatial position and add the input back.

        Args:
            x: Feature map with ``dim`` channels on axis 1 and height/width as
                its last two axes.

        Returns:
            The projected attention output plus a copy of ``x``.
        """
        height, width = x.shape[-2:]

        # Copy of the un-normalised input, kept for the skip connection.
        skip = x.clone()

        qkv = self.to_qkv(self.norm(x))

        # Split the channels into q/k/v, pull the head axis out, and flatten
        # the spatial grid into one sequence axis per head.
        q, k, v = (
            rearrange(part, "b (h c) ... -> b h (...) c", h=self.heads)
            for part in qkv.chunk(3, dim=1)
        )

        q = l2norm(q)
        k = l2norm(k)

        logits = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
        weights = logits.softmax(dim=-1)
        attended = einsum("b h i j, b h j d -> b h i d", weights, v)

        # Fold the heads back into channels and restore the 2-D layout.
        merged = rearrange(
            attended, "b h (x y) d -> b (h d) x y", x=height, y=width
        )
        return self.to_out(merged) + skip


class TransformerBlock(nn.Module):
    """Apply an attention module followed by a feed-forward module."""

    def __init__(self, attn: Attention, ff: FeedForward) -> None:
        """Store the two sub-modules.

        Args:
            attn: Attention module applied first.
            ff: Feed-forward module applied to the attention output.
        """
        super().__init__()
        # Tuple assignment binds left to right, so ``attn`` is still
        # registered before ``ff``.
        self.attn, self.ff = attn, ff

    def forward(self, x: Tensor) -> Tensor:
        """Return ``ff(attn(x))``."""
        return self.ff(self.attn(x))