summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/perceiver/attention.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-13 18:45:58 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-13 18:45:58 +0200
commitfd9b1570c568d9ce8f1ac7258f05f9977a5cc9c8 (patch)
treeb347be4efd051a9ac9edbc50ff0c48c92596ca64 /text_recognizer/networks/perceiver/attention.py
parent8b2e5296b290f147935c58207fbfd9674394c7b3 (diff)
Remove perceiver net
Diffstat (limited to 'text_recognizer/networks/perceiver/attention.py')
-rw-r--r--text_recognizer/networks/perceiver/attention.py40
1 files changed, 0 insertions, 40 deletions
diff --git a/text_recognizer/networks/perceiver/attention.py b/text_recognizer/networks/perceiver/attention.py
deleted file mode 100644
index 19e3e17..0000000
--- a/text_recognizer/networks/perceiver/attention.py
+++ /dev/null
@@ -1,40 +0,0 @@
-"""Attention module."""
-from typing import Optional
-
-from einops import rearrange, repeat
-import torch
-from torch import einsum, nn, Tensor
-import torch.nn.functional as F
-
-
-class Attention(nn.Module):
- def __init__(
- self,
- query_dim: int,
- context_dim: Optional[int] = None,
- heads: int = 8,
- dim_head: int = 64,
- ) -> None:
- super().__init__()
- inner_dim = heads * dim_head
- context_dim = context_dim if context_dim is not None else query_dim
- self.scale = dim_head ** -0.5
- self.heads = heads
-
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
- self.to_kv = nn.Linear(context_dim, 2 * inner_dim, bias=False)
- self.to_out = nn.Linear(inner_dim, query_dim, bias=False)
-
- def forward(self, x: Tensor, context: Optional[Tensor] = None) -> Tensor:
- h = self.heads
- q = self.to_q(x)
- context = context if context is not None else x
- k, v = self.to_kv(context).chunk(2, dim=-1)
-
- q, v, k = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
- sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
-
- attn = sim.softmax(dim=-1)
- out = einsum("b i j, b j d -> b i d", attn, v)
- out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
- return self.to_out(out)