Diffstat (limited to 'text_recognizer/network')
-rw-r--r--  text_recognizer/network/mammut.py | 117
1 file changed, 117 insertions, 0 deletions
diff --git a/text_recognizer/network/mammut.py b/text_recognizer/network/mammut.py
new file mode 100644
index 0000000..761e5f9
--- /dev/null
+++ b/text_recognizer/network/mammut.py
@@ -0,0 +1,117 @@
+from typing import Tuple
+
+from einops import repeat
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+
+from text_recognizer.network.transformer.decoder import Decoder
+from text_recognizer.network.transformer.attention import Attention
+
+
+class ToLatent(nn.Module):
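+    """Linear projection to a shared latent space, L2-normalized in forward."""
+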
+    def __init__(self, dim: int, dim_latent: int) -> None:
+        super().__init__()
+        self.to_latent = nn.Linear(dim, dim_latent, bias=False)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return F.normalize(self.to_latent(x), dim=-1)
+
+
+class MaMMUT(nn.Module):
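+    """MaMMUT-style two-tower model: a vision encoder pooled with learned
+    queries and a text decoder shared between the contrastive and the
+    generative (captioning) objective."""
+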
+    def __init__(
+        self,
+        encoder: nn.Module,
+        image_attn_pool: Attention,
+        decoder: Decoder,
+        dim: int,
+        dim_latent: int,
+        num_tokens: int,
+        pad_index: int,
+        num_image_queries: int = 256,
+    ) -> None:
+        super().__init__()
+        # Image encoder
+        self.encoder = encoder
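+        # Learned pooling queries (num_image_queries + 1); the first pooled
+        # output serves as the contrastive image embedding.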
+        self.image_queries = nn.Parameter(torch.randn(num_image_queries + 1, dim))
+        self.image_attn_pool = image_attn_pool
+        self.image_attn_pool_norm = nn.LayerNorm(dim)
+
+        # Text decoder
+        self.decoder = decoder
+        self.pad_index = pad_index
+        self.token_embeddings = nn.Embedding(num_tokens, dim)
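+        # Learned CLS token appended to the text sequence; its output state
+        # becomes the contrastive text embedding.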
+        self.text_cls_token = nn.Parameter(torch.randn(dim))
+        self.text_cls_norm = nn.LayerNorm(dim)
+
+        # To latents
+        self.to_image_latent = ToLatent(dim, dim_latent)
+        self.to_text_latent = ToLatent(dim, dim_latent)
+
+        self.to_logits = nn.Sequential(
+            nn.LayerNorm(dim), nn.Linear(dim, num_tokens, bias=False)
+        )
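+        # Tie the output projection to the token embedding weights.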
+        self.to_logits[-1].weight = self.token_embeddings.weight
+        nn.init.normal_(self.token_embeddings.weight, std=0.02)
+
+    def to_image_features(self, images: Tensor) -> Tuple[Tensor, Tensor]:
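+        """Attention-pools encoder tokens; returns (image embedding, image features)."""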
+        image_tokens = self.encoder(images)
+        queries = repeat(self.image_queries, "n d -> b n d", b=image_tokens.shape[0])
+        queries = self.image_attn_pool(queries, image_tokens)
+        queries = self.image_attn_pool_norm(queries)
+        return queries[:, 0], queries[:, 1:]
+
+    def to_text_cls_features(self, text: Tensor) -> Tensor:
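+        """Runs the decoder's self-attention only (no image cross-attention)
+        and returns the normalized CLS embedding."""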
+        b = text.shape[0]
+        text = text.long()
+        mask = text != self.pad_index
+        tokens = self.token_embeddings(text)
+
+        # Append cls token
+        cls_tokens = repeat(self.text_cls_token, "d -> b 1 d", b=b)
+        tokens = torch.cat([tokens, cls_tokens], dim=-2)
+
+        # Create input mask
+        mask = F.pad(mask, pad=(0, 1), value=True)
+
+        features = self.decoder.self_attn(tokens, mask)
+        cls_embedding = features[:, -1]
+        return self.text_cls_norm(cls_embedding)
+
+    def to_latents(self, image_embeddings: Tensor, text_embeddings: Tensor) -> Tuple[Tensor, Tensor]:
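+        """Projects image and text embeddings into the shared contrastive latent space."""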
+        image_latent = self.to_image_latent(image_embeddings)
+        text_latent = self.to_text_latent(text_embeddings)
+        return image_latent, text_latent
+
+    def decode(self, text: Tensor, image_features: Tensor) -> Tensor:
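+        """Decodes text with cross-attention over image features; returns token logits."""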
+        text = text.long()
+        mask = text != self.pad_index
+        tokens = self.token_embeddings(text)
+        text_features = self.decoder(tokens, context=image_features, mask=mask)
+        return self.to_logits(text_features)
+
+    def encode(self, images: Tensor) -> Tensor:
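+        """Encodes images into pooled features for the decoder's cross-attention."""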
+        _, image_features = self.to_image_features(images)
+        return image_features
+
+    def forward(self, images: Tensor, text: Tensor) -> Tensor:
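+        """Computes captioning logits for text conditioned on the images."""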
+        image_features = self.encode(images)
+        return self.decode(text, image_features)