diff options
| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2023-09-11 22:14:58 +0200 | 
|---|---|---|
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2023-09-11 22:14:58 +0200 | 
| commit | c90a69b41493a280085b373420031184886fefa1 (patch) | |
| tree | 89c432e5d594438b2c24301a4aec57c0bda93b6f | |
| parent | 3a4fa0f1cf3c381a3d81f39e6e9439ec192b2509 (diff) | |
Add MaMMUT network
| -rw-r--r-- | text_recognizer/network/mammut.py | 99 | 
1 files changed, 99 insertions, 0 deletions
diff --git a/text_recognizer/network/mammut.py b/text_recognizer/network/mammut.py new file mode 100644 index 0000000..761e5f9 --- /dev/null +++ b/text_recognizer/network/mammut.py @@ -0,0 +1,99 @@ +from typing import Type + +from einops import repeat +import torch +from torch import nn, Tensor +import torch.nn.functional as F + +from text_recognizer.network.transformer.decoder import Decoder +from text_recognizer.network.transformer.attention import Attention + + +class ToLatent(nn.Module): +    def __init__(self, dim: int, dim_latent: int) -> None: +        super().__init__() +        self.to_latent = nn.Linear(dim, dim_latent, bias=False) + +    def forward(self, x: Tensor) -> Tensor: +        return F.normalize(self.to_latent(x), dim=-1) + + +class MaMMUT(nn.Module): +    def __init__( +        self, +        encoder: Type[nn.Module], +        image_attn_pool: Attention, +        decoder: Decoder, +        dim: int, +        dim_latent: int, +        num_tokens: int, +        pad_index: int, +        num_image_queries: int = 256, +    ) -> None: +        super().__init__() +        # Image encoder +        self.encoder = encoder +        self.image_queries = nn.Parameter(torch.randn(num_image_queries + 1, dim)) +        self.image_attn_pool = image_attn_pool +        self.image_attn_pool_norm = nn.LayerNorm(dim) + +        # Text decoder +        self.decoder = decoder +        self.pad_index = pad_index +        self.token_embeddings = nn.Embedding(num_tokens, dim) +        self.text_cls_token = nn.Parameter(torch.randn(dim)) +        self.text_cls_norm = nn.LayerNorm(dim) + +        # To latents +        self.to_image_latent = ToLatent(dim, dim_latent) +        self.to_text_latent = ToLatent(dim, dim_latent) + +        self.to_logits = nn.Sequential( +            nn.LayerNorm(dim), nn.Linear(dim, num_tokens, bias=False) +        ) +        self.to_logits[-1].weight = self.token_embeddings.weight +        nn.init.normal_(self.token_embeddings.weight, std=0.02) + +    def to_image_features(self, images: Tensor) -> Tensor: +        image_tokens = self.encoder(images) +        queries = repeat(self.image_queries, "n d -> b n d", b=image_tokens.shape[0]) +        queries = self.image_attn_pool(queries, image_tokens) +        queries = self.image_attn_pool_norm(queries) +        return queries[:, 0], queries[:, 1:] + +    def to_text_cls_features(self, text: Tensor) -> Tensor: +        b = text.shape[0] +        text = text.long() +        mask = text != self.pad_index +        tokens = self.token_embeddings(text) + +        # Append cls token +        cls_tokens = repeat(self.text_cls_token, "d -> b 1 d", b=b) +        tokens = torch.cat([tokens, cls_tokens], dim=-2) + +        # Create input mask +        mask = F.pad(mask, pad=(0, 1), value=True) + +        features = self.decoder.self_attn(tokens, mask) +        cls_embedding = features[:, -1] +        return self.text_cls_norm(cls_embedding) + +    def to_latents(self, image_embeddings: Tensor, text_embeddings: Tensor) -> Tensor: +        image_latent = self.to_image_latent(image_embeddings) +        text_latent = self.to_text_latent(text_embeddings) +        return image_latent, text_latent + +    def decode(self, text: Tensor, image_features: Tensor) -> Tensor: +        text = text.long() +        mask = text != self.pad_index +        tokens = self.token_embeddings(text) +        text_features = self.decoder(tokens, context=image_features, mask=mask) +        return self.to_logits(text_features) + +    def encode(self, images: Tensor) -> Tensor: +        _, image_features = self.to_image_features(images) +        return image_features + +    def forward(self, images: Tensor, text: Tensor) -> Tensor: +        image_features = self.encode(images) +        return self.decode(text, image_features)  |