summaryrefslogtreecommitdiff
path: root/text_recognizer/network/convnext/transformer.py
blob: 6c53c4830f9ae3c18c5239fc852eaa2280f588d1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""Convolution self attention block."""

from einops import rearrange
from torch import Tensor, einsum, nn

from text_recognizer.network.convnext.norm import LayerNorm


class FeedForward(nn.Module):
    """Pre-normalized pointwise feed-forward block built from 1x1 convolutions.

    The stack is:
    LayerNorm -> 1x1 conv expanding to ``int(dim * mult)`` channels -> GELU
    -> LayerNorm -> 1x1 conv projecting back to ``dim`` channels.
    Both convolutions are bias-free and use kernel size 1, so the spatial
    extent of the input is not changed by them.

    Args:
        dim: Number of input and output channels.
        mult: Expansion factor for the hidden channel count.
    """

    def __init__(self, dim: int, mult: int = 4) -> None:
        super().__init__()
        hidden = int(dim * mult)
        layers = [
            LayerNorm(dim),
            nn.Conv2d(dim, hidden, 1, bias=False),
            nn.GELU(),
            LayerNorm(hidden),
            nn.Conv2d(hidden, dim, 1, bias=False),
        ]
        # The Sequential stays under the attribute ``fn`` so state-dict keys
        # (fn.0, fn.1, ...) are unchanged.
        self.fn = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through the feed-forward stack.

        Args:
            x: Feature map; presumably (batch, dim, height, width) since it
                feeds ``nn.Conv2d`` — TODO confirm with callers.

        Returns:
            The transformed tensor with ``dim`` channels.
        """
        out = self.fn(x)
        return out


class Attention(nn.Module):
    """Multi-head self-attention across the spatial positions of a feature map.

    The input is normalized with ``LayerNorm`` and projected to queries, keys
    and values by one bias-free 1x1 convolution. Every spatial position then
    attends to every other position, and the result is projected back to
    ``dim`` channels by a second bias-free 1x1 convolution.

    Queries and keys are L2-normalized per head before the dot product. The
    attention logits are therefore cosine similarities in ``[-1, 1]``
    multiplied by the temperature ``scale``. A fixed multiplier such as the
    default ``scale=8`` is only meaningful on normalized vectors. Applied to
    raw dot products it would enlarge the logits rather than shrink them by
    ``dim_head ** -0.5``, driving the softmax towards one-hot.

    Args:
        dim: Number of input and output channels.
        heads: Number of attention heads.
        dim_head: Number of channels per head.
        scale: Temperature multiplying the cosine-similarity logits.
    """

    def __init__(
        self, dim: int, heads: int = 4, dim_head: int = 64, scale: float = 8
    ) -> None:
        super().__init__()
        self.scale = scale
        self.heads = heads
        inner_dim = heads * dim_head
        self.norm = LayerNorm(dim)

        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        """Attend across all spatial positions of ``x``.

        Args:
            x: Feature map; assumed (batch, dim, height, width) — the last two
                axes are treated as the spatial grid and restored at the end.

        Returns:
            Tensor with ``dim`` channels on the same ``height x width`` grid.
        """
        h, w = x.shape[-2:]

        x = self.norm(x)

        q, k, v = self.to_qkv(x).chunk(3, dim=1)
        # Split channels into heads and flatten the spatial grid so each pixel
        # becomes one token: (b, heads * d, h, w) -> (b, heads, h * w, d).
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) ... -> b h (...) c", h=self.heads),
            (q, k, v),
        )

        # Cosine-similarity attention: unit-normalize along the head dimension
        # so ``scale`` acts as a bounded temperature instead of amplifying
        # unnormalized dot products.
        q = nn.functional.normalize(q, dim=-1)
        k = nn.functional.normalize(k, dim=-1)

        sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        # Fold heads back into channels and restore the original spatial grid.
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.to_out(out)


class Transformer(nn.Module):
    """Residual block that applies attention followed by a feed-forward layer.

    Each sub-module's output is added back onto its own input, with attention
    running first and the feed-forward layer second.

    Args:
        attn: Attention sub-module.
        ff: Feed-forward sub-module.
    """

    def __init__(self, attn: Attention, ff: FeedForward) -> None:
        super().__init__()
        self.attn = attn
        self.ff = ff

    def forward(self, x: Tensor) -> Tensor:
        """Apply ``attn`` then ``ff``, each wrapped in a residual connection."""
        for sublayer in (self.attn, self.ff):
            x = x + sublayer(x)
        return x