summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/perceiver
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/perceiver')
-rw-r--r--text_recognizer/networks/perceiver/__init__.py0
-rw-r--r--text_recognizer/networks/perceiver/attention.py48
-rw-r--r--text_recognizer/networks/perceiver/perceiver.py89
3 files changed, 137 insertions, 0 deletions
diff --git a/text_recognizer/networks/perceiver/__init__.py b/text_recognizer/networks/perceiver/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/text_recognizer/networks/perceiver/__init__.py
diff --git a/text_recognizer/networks/perceiver/attention.py b/text_recognizer/networks/perceiver/attention.py
new file mode 100644
index 0000000..66aeaa8
--- /dev/null
+++ b/text_recognizer/networks/perceiver/attention.py
@@ -0,0 +1,48 @@
+"""Attention module."""
+from typing import Optional
+
+from einops import rearrange, repeat
+import torch
+from torch import einsum, nn, Tensor
+import torch.nn.functional as F
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ query_dim: int,
+ context_dim: Optional[int] = None,
+ heads: int = 8,
+ dim_head: int = 64,
+ ) -> None:
+ super().__init__()
+ inner_dim = heads * dim_head
+ context_dim = context_dim if context_dim is not None else query_dim
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_kv = nn.Linear(context_dim, 2 * inner_dim, bias=False)
+ self.to_out = nn.Linear(inner_dim, query_dim, bias=False)
+
+ def forward(
+ self, x: Tensor, context: Optional[Tensor] = None, mask=Optional[Tensor]
+ ) -> Tensor:
+ h = self.heads
+ q = self.to_q(x)
+ context = context if context is not None else x
+ k, v = self.to_kv(context).chunk(2, dim=-1)
+
+ q, v, k = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
+ sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
+
+ if mask is not None:
+ mask = rearrange(mask, "b ... -> b (...)")
+ max_neg_val = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, "b j -> (b h) () j", h=h)
+ sim.masked_fill_(~mask, max_neg_val)
+
+ attn = sim.softmax(dim=-1)
+ out = einsum("b i j, b j d -> b i d", attn, v)
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
+ return self.to_out(out)
diff --git a/text_recognizer/networks/perceiver/perceiver.py b/text_recognizer/networks/perceiver/perceiver.py
new file mode 100644
index 0000000..65ee20c
--- /dev/null
+++ b/text_recognizer/networks/perceiver/perceiver.py
@@ -0,0 +1,89 @@
+"""Perceiver IO.
+
+A copy from lucidrains.
+"""
+from itertools import repeat
+from typing import Optional
+
+import torch
+from torch import nn, Tensor
+
+from text_recognizer.networks.perceiver.attention import Attention
+from text_recognizer.networks.transformer.ff import FeedForward
+from text_recognizer.networks.transformer.norm import PreNorm
+
+
+class PerceiverIO(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ cross_heads: int,
+ cross_head_dim: int,
+ num_latents: int,
+ latent_dim: int,
+ latent_heads: int,
+ depth: int,
+ queries_dim: int,
+ logits_dim: int,
+ ) -> None:
+ super().__init__()
+ self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
+
+ self.cross_attn_block = nn.ModuleList(
+ [
+ PreNorm(
+ latent_dim,
+ Attention(
+ latent_dim, dim, heads=cross_heads, dim_head=cross_head_dim
+ ),
+ context_dim=dim,
+ ),
+ PreNorm(latent_dim, FeedForward(latent_dim)),
+ ]
+ )
+
+ self.layers = nn.ModuleList(
+ [
+ [
+ PreNorm(
+ latent_dim,
+ Attention(latent_dim, heads=latent_heads, dim_head=latent_dim),
+ ),
+ PreNorm(latent_dim, FeedForward(latent_dim)),
+ ]
+ for _ in range(depth)
+ ]
+ )
+
+ self.decoder_cross_attn = PreNorm(
+ queries_dim,
+ Attention(
+ queries_dim, latent_dim, heads=cross_heads, dim_head=cross_head_dim
+ ),
+ context_dim=latent_dim,
+ )
+ self.decoder_ff = PreNorm(queries_dim, FeedForward(queries_dim))
+ self.to_logits = nn.Linear(queries_dim, logits_dim)
+
+ def forward(
+ self, data: Tensor, queries: Tensor, mask: Optional[Tensor] = None
+ ) -> Tensor:
+ b = data.shape[0]
+ x = repeat(self.latents, "nd -> bnd", b=b)
+
+ cross_attn, cross_ff = self.cross_attn_block
+
+ x = cross_attn(x, context=data, mask=mask) + x
+ x = cross_ff(x) + x
+
+ for attn, ff in self.layers:
+ x = attn(x) + x
+ x = ff(x) + x
+
+ if queries.ndim == 2:
+ queries = repeat(queries, "nd->bnd", b=b)
+
+ latents = self.decoder_cross_attn(queries, context=x)
+ latents = latents + self.decoder_ff(latents)
+
+ return self.to_logits(latents)