"""MaMMUT: an image-text model with a single text decoder that serves both as
a contrastive text encoder (self-attention only) and as an image-conditioned
caption decoder (with cross-attention)."""
import torch
import torch.nn.functional as F
from einops import repeat
from torch import Tensor, nn

from text_recognizer.network.transformer.attention import Attention
from text_recognizer.network.transformer.decoder import Decoder


class ToLatent(nn.Module):
    """Linear projection into the shared latent space, L2-normalized."""

    def __init__(self, dim: int, dim_latent: int) -> None:
        super().__init__()
        self.to_latent = nn.Linear(dim, dim_latent, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        return F.normalize(self.to_latent(x), dim=-1)


class MaMMUT(nn.Module):
    """MaMMUT model with an image encoder and a multi-task text decoder."""

    def __init__(
        self,
        encoder: nn.Module,
        image_attn_pool: Attention,
        decoder: Decoder,
        dim: int,
        dim_latent: int,
        num_tokens: int,
        pad_index: int,
        num_image_queries: int = 256,
    ) -> None:
        super().__init__()
        # Image encoder with learned attention pooling. One extra query is
        # reserved for the pooled image embedding used for contrastive learning.
        self.encoder = encoder
        self.image_queries = nn.Parameter(torch.randn(num_image_queries + 1, dim))
        self.image_attn_pool = image_attn_pool
        self.image_attn_pool_norm = nn.LayerNorm(dim)

        # Text decoder.
        self.decoder = decoder
        self.pad_index = pad_index
        self.token_embeddings = nn.Embedding(num_tokens, dim)
        self.text_cls_token = nn.Parameter(torch.randn(dim))
        self.text_cls_norm = nn.LayerNorm(dim)

        # Projections into the shared image-text latent space.
        self.to_image_latent = ToLatent(dim, dim_latent)
        self.to_text_latent = ToLatent(dim, dim_latent)

        # Output head, weight-tied to the token embeddings.
        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim), nn.Linear(dim, num_tokens, bias=False)
        )
        self.to_logits[-1].weight = self.token_embeddings.weight
        nn.init.normal_(self.token_embeddings.weight, std=0.02)

    def to_image_features(self, images: Tensor) -> tuple[Tensor, Tensor]:
        """Encodes images and pools them with the learned queries.

        Returns the pooled embedding (first query) and the remaining
        query features.
        """
        image_tokens = self.encoder(images)
        queries = repeat(self.image_queries, "n d -> b n d", b=image_tokens.shape[0])
        queries = self.image_attn_pool(queries, image_tokens)
        queries = self.image_attn_pool_norm(queries)
        return queries[:, 0], queries[:, 1:]

    def to_text_cls_features(self, text: Tensor) -> Tensor:
        """Embeds text through the decoder's self-attention path and returns
        the normalized CLS embedding."""
        b = text.shape[0]
        text = text.long()
        mask = text != self.pad_index
        tokens = self.token_embeddings(text)

        # Append the CLS token.
        cls_tokens = repeat(self.text_cls_token, "d -> b 1 d", b=b)
        tokens = torch.cat([tokens, cls_tokens], dim=-2)

        # Extend the padding mask to cover the appended CLS token.
        mask = F.pad(mask, pad=(0, 1), value=True)

        features = self.decoder.self_attn(tokens, mask)
        cls_embedding = features[:, -1]
        return self.text_cls_norm(cls_embedding)

    def to_latents(
        self, image_embeddings: Tensor, text_embeddings: Tensor
    ) -> tuple[Tensor, Tensor]:
        image_latent = self.to_image_latent(image_embeddings)
        text_latent = self.to_text_latent(text_embeddings)
        return image_latent, text_latent

    def decode(self, text: Tensor, image_features: Tensor) -> Tensor:
        """Decodes text conditioned on image features via cross-attention."""
        text = text.long()
        mask = text != self.pad_index
        tokens = self.token_embeddings(text)
        text_features = self.decoder(tokens, context=image_features, mask=mask)
        return self.to_logits(text_features)

    def encode(self, images: Tensor) -> Tensor:
        _, image_features = self.to_image_features(images)
        return image_features

    def forward(self, images: Tensor, text: Tensor) -> Tensor:
        image_features = self.encode(images)
        return self.decode(text, image_features)
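
# A minimal usage sketch (an illustration, not part of the module): the
# encoder, attention pool, and decoder below stand in for instances built
# elsewhere in the repo, and the dimensions are assumed example values.
#
#     model = MaMMUT(
#         encoder=encoder,            # assumed to map images to (B, N, dim) tokens
#         image_attn_pool=image_attn_pool,
#         decoder=decoder,
#         dim=512,
#         dim_latent=512,
#         num_tokens=1000,
#         pad_index=0,
#     )
#     logits = model(images, text)    # (B, T, num_tokens) captioning logits
#     image_cls, _ = model.to_image_features(images)
#     text_cls = model.to_text_cls_features(text)
#     image_latent, text_latent = model.to_latents(image_cls, text_cls)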