summaryrefslogtreecommitdiff
path: root/text_recognizer/network/mammut.py
blob: 761e5f93228bfb22156f0ac1540f7fd44ac8cc43 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from typing import Tuple, Type

from einops import repeat
import torch
from torch import nn, Tensor
import torch.nn.functional as F

from text_recognizer.network.transformer.decoder import Decoder
from text_recognizer.network.transformer.attention import Attention


class ToLatent(nn.Module):
    """Project features into a latent space and normalise them.

    The projection is a linear layer without bias; the projected vectors are
    passed through ``F.normalize`` along the last dimension.
    """

    def __init__(self, dim: int, dim_latent: int) -> None:
        """Create the projection from ``dim`` to ``dim_latent`` features."""
        super().__init__()
        self.to_latent = nn.Linear(
            in_features=dim,
            out_features=dim_latent,
            bias=False,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return the normalised latent projection of ``x``."""
        projected = self.to_latent(x)
        return F.normalize(projected, dim=-1)


class MaMMUT(nn.Module):
    """Image-to-text model built from an image encoder and a text decoder.

    Encoded image tokens are pooled by a set of learned queries through
    ``image_attn_pool``. The first pooled query is returned separately from
    the remaining ones; the remaining ones are handed to the decoder as
    ``context``. On the text side, token ids are embedded, optionally followed
    by a learned CLS token to obtain a single text embedding, and decoder
    outputs are mapped to vocabulary logits by a projection whose weight is
    shared with the token embedding table.
    """

    def __init__(
        self,
        encoder: nn.Module,
        image_attn_pool: Attention,
        decoder: Decoder,
        dim: int,
        dim_latent: int,
        num_tokens: int,
        pad_index: int,
        num_image_queries: int = 256,
    ) -> None:
        """Build the model.

        Args:
            encoder: Module instance called on the input images; the first
                dimension of its output is used as the batch size.
            image_attn_pool: Module called as
                ``image_attn_pool(queries, image_tokens)``.
            decoder: Module called as
                ``decoder(tokens, context=..., mask=...)``; it must also
                expose ``self_attn(tokens, mask)``.
            dim: Feature size of the image queries, the token embeddings and
                the layer norms.
            dim_latent: Output size of the image and text latent projections.
            num_tokens: Size of the token vocabulary.
            pad_index: Token id treated as padding when masks are built.
            num_image_queries: Number of learned queries returned as image
                features. One additional query is allocated; its pooled
                output is returned separately by ``to_image_features``.
        """
        super().__init__()
        # Image encoder
        self.encoder = encoder
        # One extra query on top of ``num_image_queries`` (see docstring).
        self.image_queries = nn.Parameter(torch.randn(num_image_queries + 1, dim))
        self.image_attn_pool = image_attn_pool
        self.image_attn_pool_norm = nn.LayerNorm(dim)

        # Text decoder
        self.decoder = decoder
        self.pad_index = pad_index
        self.token_embeddings = nn.Embedding(num_tokens, dim)
        self.text_cls_token = nn.Parameter(torch.randn(dim))
        self.text_cls_norm = nn.LayerNorm(dim)

        # To latents
        self.to_image_latent = ToLatent(dim, dim_latent)
        self.to_text_latent = ToLatent(dim, dim_latent)

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim), nn.Linear(dim, num_tokens, bias=False)
        )
        # Share the output projection weight with the embedding table, then
        # initialise that single shared tensor.
        self.to_logits[-1].weight = self.token_embeddings.weight
        nn.init.normal_(self.token_embeddings.weight, std=0.02)

    def to_image_features(self, images: Tensor) -> Tuple[Tensor, Tensor]:
        """Encode images and pool the resulting tokens with the learned queries.

        Args:
            images: Batch of images accepted by ``self.encoder``.

        Returns:
            A pair ``(first, rest)`` where ``first`` is index 0 along the
            second dimension of the normalised pooled queries and ``rest``
            holds all remaining indices.
        """
        image_tokens = self.encoder(images)
        # Give every batch element its own copy of the learned queries.
        queries = repeat(self.image_queries, "n d -> b n d", b=image_tokens.shape[0])
        # NOTE(review): presumably queries attend over the image tokens here;
        # confirm against the Attention implementation.
        queries = self.image_attn_pool(queries, image_tokens)
        queries = self.image_attn_pool_norm(queries)
        return queries[:, 0], queries[:, 1:]

    def to_text_cls_features(self, text: Tensor) -> Tensor:
        """Compute one normalised embedding per text sequence via a CLS token.

        Args:
            text: Batch of token ids; cast to ``long`` before embedding.

        Returns:
            The layer-normalised feature taken from the last position of the
            ``decoder.self_attn`` output, i.e. where the CLS token was placed.
        """
        b = text.shape[0]
        text = text.long()
        # True for every position that is not padding.
        mask = text != self.pad_index
        tokens = self.token_embeddings(text)

        # Append cls token
        cls_tokens = repeat(self.text_cls_token, "d -> b 1 d", b=b)
        tokens = torch.cat([tokens, cls_tokens], dim=-2)

        # Create input mask
        # One extra ``True`` on the right so the appended CLS position is kept.
        mask = F.pad(mask, pad=(0, 1), value=True)

        # NOTE(review): assumes ``self_attn`` preserves the sequence layout so
        # the last position still corresponds to the CLS token - confirm.
        features = self.decoder.self_attn(tokens, mask)
        cls_embedding = features[:, -1]
        return self.text_cls_norm(cls_embedding)

    def to_latents(
        self, image_embeddings: Tensor, text_embeddings: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """Project image and text embeddings into the latent space.

        Args:
            image_embeddings: Input to the image latent projection.
            text_embeddings: Input to the text latent projection.

        Returns:
            A pair ``(image_latent, text_latent)``.
        """
        image_latent = self.to_image_latent(image_embeddings)
        text_latent = self.to_text_latent(text_embeddings)
        return image_latent, text_latent

    def decode(self, text: Tensor, image_features: Tensor) -> Tensor:
        """Run the decoder on embedded text with image features as context.

        Args:
            text: Batch of token ids; cast to ``long`` before embedding.
            image_features: Passed to the decoder as ``context``.

        Returns:
            The output of ``self.to_logits`` applied to the decoder features.
        """
        text = text.long()
        # True for every position that is not padding.
        mask = text != self.pad_index
        tokens = self.token_embeddings(text)
        text_features = self.decoder(tokens, context=image_features, mask=mask)
        return self.to_logits(text_features)

    def encode(self, images: Tensor) -> Tensor:
        """Return the pooled image features, dropping the first pooled query."""
        _, image_features = self.to_image_features(images)
        return image_features

    def forward(self, images: Tensor, text: Tensor) -> Tensor:
        """Encode ``images`` and decode ``text`` against them, returning logits."""
        image_features = self.encode(images)
        return self.decode(text, image_features)