summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/perceiver/attention.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-03 00:55:14 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-03 00:55:14 +0200
commit5dd76ca9a3ff35c57cbc7c607afbdb4ee1c8b36f (patch)
tree0d27d99fb3484fd6d871549c465d548265736d67 /text_recognizer/networks/perceiver/attention.py
parent9d136d11c6edbae2539e3b4a952bd39e9dbcaa68 (diff)
Add perceiver
Diffstat (limited to 'text_recognizer/networks/perceiver/attention.py')
-rw-r--r--text_recognizer/networks/perceiver/attention.py48
1 files changed, 48 insertions, 0 deletions
diff --git a/text_recognizer/networks/perceiver/attention.py b/text_recognizer/networks/perceiver/attention.py
new file mode 100644
index 0000000..66aeaa8
--- /dev/null
+++ b/text_recognizer/networks/perceiver/attention.py
@@ -0,0 +1,48 @@
+"""Attention module."""
+from typing import Optional
+
+from einops import rearrange, repeat
+import torch
+from torch import einsum, nn, Tensor
+import torch.nn.functional as F
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ query_dim: int,
+ context_dim: Optional[int] = None,
+ heads: int = 8,
+ dim_head: int = 64,
+ ) -> None:
+ super().__init__()
+ inner_dim = heads * dim_head
+ context_dim = context_dim if context_dim is not None else query_dim
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_kv = nn.Linear(context_dim, 2 * inner_dim, bias=False)
+ self.to_out = nn.Linear(inner_dim, query_dim, bias=False)
+
+ def forward(
+ self, x: Tensor, context: Optional[Tensor] = None, mask=Optional[Tensor]
+ ) -> Tensor:
+ h = self.heads
+ q = self.to_q(x)
+ context = context if context is not None else x
+ k, v = self.to_kv(context).chunk(2, dim=-1)
+
+ q, v, k = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
+ sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
+
+ if mask is not None:
+ mask = rearrange(mask, "b ... -> b (...)")
+ max_neg_val = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, "b j -> (b h) () j", h=h)
+ sim.masked_fill_(~mask, max_neg_val)
+
+ attn = sim.softmax(dim=-1)
+ out = einsum("b i j, b j d -> b i d", attn, v)
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
+ return self.to_out(out)