From 5dd76ca9a3ff35c57cbc7c607afbdb4ee1c8b36f Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Sat, 3 Sep 2022 00:55:14 +0200
Subject: Add perceiver
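
A Perceiver IO encoder/decoder: a cross-attention block maps the input
array onto a fixed set of learned latents, a stack of latent
self-attention layers processes them, and a query-based cross-attention
decoder reads the latents out into logits.

A rough usage sketch (hyperparameters and tensor shapes are
illustrative only, not part of this commit):

    import torch
    from text_recognizer.networks.perceiver.perceiver import PerceiverIO

    net = PerceiverIO(
        dim=256,
        cross_heads=1,
        cross_head_dim=64,
        num_latents=256,
        latent_dim=512,
        latent_heads=8,
        depth=6,
        queries_dim=256,
        logits_dim=58,
    )
    data = torch.randn(2, 1024, 256)  # (batch, input tokens, dim)
    queries = torch.randn(32, 256)    # (output tokens, queries_dim); repeated over batch
    logits = net(data, queries)       # -> (2, 32, 58)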

---
 text_recognizer/networks/perceiver/__init__.py  |  0
 text_recognizer/networks/perceiver/attention.py | 48 +++++++++++++
 text_recognizer/networks/perceiver/perceiver.py | 89 +++++++++++++++++++++++++
 3 files changed, 137 insertions(+)
 create mode 100644 text_recognizer/networks/perceiver/__init__.py
 create mode 100644 text_recognizer/networks/perceiver/attention.py
 create mode 100644 text_recognizer/networks/perceiver/perceiver.py

diff --git a/text_recognizer/networks/perceiver/__init__.py b/text_recognizer/networks/perceiver/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/text_recognizer/networks/perceiver/attention.py b/text_recognizer/networks/perceiver/attention.py
new file mode 100644
index 0000000..66aeaa8
--- /dev/null
+++ b/text_recognizer/networks/perceiver/attention.py
@@ -0,0 +1,48 @@
+"""Attention module."""
+from typing import Optional
+
+from einops import rearrange, repeat
+import torch
+from torch import einsum, nn, Tensor
+import torch.nn.functional as F
+
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        query_dim: int,
+        context_dim: Optional[int] = None,
+        heads: int = 8,
+        dim_head: int = 64,
+    ) -> None:
+        super().__init__()
+        inner_dim = heads * dim_head
+        context_dim = context_dim if context_dim is not None else query_dim
+        self.scale = dim_head ** -0.5
+        self.heads = heads
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_kv = nn.Linear(context_dim, 2 * inner_dim, bias=False)
+        self.to_out = nn.Linear(inner_dim, query_dim, bias=False)
+
+    def forward(
+        self, x: Tensor, context: Optional[Tensor] = None, mask: Optional[Tensor] = None
+    ) -> Tensor:
+        h = self.heads
+        q = self.to_q(x)
+        context = context if context is not None else x
+        k, v = self.to_kv(context).chunk(2, dim=-1)
+
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
+        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
+
+        if mask is not None:
+            mask = rearrange(mask, "b ... -> b (...)")
+            max_neg_val = -torch.finfo(sim.dtype).max
+            mask = repeat(mask, "b j -> (b h) () j", h=h)
+            sim.masked_fill_(~mask, max_neg_val)
+
+        attn = sim.softmax(dim=-1)
+        out = einsum("b i j, b j d -> b i d", attn, v)
+        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
+        return self.to_out(out)
diff --git a/text_recognizer/networks/perceiver/perceiver.py b/text_recognizer/networks/perceiver/perceiver.py
new file mode 100644
index 0000000..65ee20c
--- /dev/null
+++ b/text_recognizer/networks/perceiver/perceiver.py
@@ -0,0 +1,89 @@
+"""Perceiver IO.
+
+Adapted from the Perceiver IO implementation by lucidrains.
+"""
+from einops import repeat
+from typing import Optional
+
+import torch
+from torch import nn, Tensor
+
+from text_recognizer.networks.perceiver.attention import Attention
+from text_recognizer.networks.transformer.ff import FeedForward
+from text_recognizer.networks.transformer.norm import PreNorm
+
+
+class PerceiverIO(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        cross_heads: int,
+        cross_head_dim: int,
+        num_latents: int,
+        latent_dim: int,
+        latent_heads: int,
+        depth: int,
+        queries_dim: int,
+        logits_dim: int,
+    ) -> None:
+        super().__init__()
+        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
+
+        self.cross_attn_block = nn.ModuleList(
+            [
+                PreNorm(
+                    latent_dim,
+                    Attention(
+                        latent_dim, dim, heads=cross_heads, dim_head=cross_head_dim
+                    ),
+                    context_dim=dim,
+                ),
+                PreNorm(latent_dim, FeedForward(latent_dim)),
+            ]
+        )
+
+        self.layers = nn.ModuleList(
+            [
+                nn.ModuleList([
+                    PreNorm(
+                        latent_dim,
+                        Attention(latent_dim, heads=latent_heads, dim_head=latent_dim),
+                    ),
+                    PreNorm(latent_dim, FeedForward(latent_dim)),
+                ])
+                for _ in range(depth)
+            ]
+        )
+
+        self.decoder_cross_attn = PreNorm(
+            queries_dim,
+            Attention(
+                queries_dim, latent_dim, heads=cross_heads, dim_head=cross_head_dim
+            ),
+            context_dim=latent_dim,
+        )
+        self.decoder_ff = PreNorm(queries_dim, FeedForward(queries_dim))
+        self.to_logits = nn.Linear(queries_dim, logits_dim)
+
+    def forward(
+        self, data: Tensor, queries: Tensor, mask: Optional[Tensor] = None
+    ) -> Tensor:
+        b = data.shape[0]
+        x = repeat(self.latents, "n d -> b n d", b=b)
+
+        cross_attn, cross_ff = self.cross_attn_block
+
+        x = cross_attn(x, context=data, mask=mask) + x
+        x = cross_ff(x) + x
+
+        for attn, ff in self.layers:
+            x = attn(x) + x
+            x = ff(x) + x
+
+        if queries.ndim == 2:
+            queries = repeat(queries, "n d -> b n d", b=b)
+
+        latents = self.decoder_cross_attn(queries, context=x)
+        latents = latents + self.decoder_ff(latents)
+
+        return self.to_logits(latents)
-- 
cgit v1.2.3-70-g09d2