summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/perceiver
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-05 00:04:55 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-05 00:04:55 +0200
commit4b112d4164d4de103997c5ddcadb802ab6440c8d (patch)
tree5394fefaf3a3c30fdd2ad714185e374a0ae8c6bb /text_recognizer/networks/perceiver
parentfc7fb0df5aa704aab3d73eab964631c8be924c42 (diff)
Remove mask from perceiver attention
Diffstat (limited to 'text_recognizer/networks/perceiver')
-rw-r--r--text_recognizer/networks/perceiver/attention.py10
1 files changed, 1 insertions, 9 deletions
diff --git a/text_recognizer/networks/perceiver/attention.py b/text_recognizer/networks/perceiver/attention.py
index 0ee51b1..19e3e17 100644
--- a/text_recognizer/networks/perceiver/attention.py
+++ b/text_recognizer/networks/perceiver/attention.py
@@ -25,9 +25,7 @@ class Attention(nn.Module):
self.to_kv = nn.Linear(context_dim, 2 * inner_dim, bias=False)
self.to_out = nn.Linear(inner_dim, query_dim, bias=False)
- def forward(
- self, x: Tensor, context: Optional[Tensor] = None, mask=Optional[Tensor]
- ) -> Tensor:
+ def forward(self, x: Tensor, context: Optional[Tensor] = None) -> Tensor:
h = self.heads
q = self.to_q(x)
context = context if context is not None else x
@@ -36,12 +34,6 @@ class Attention(nn.Module):
q, v, k = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
- # if mask is not None:
- # mask = rearrange(mask, "b ... -> b (...)")
- # max_neg_val = -torch.finfo(sim.dtype).max
- # mask = repeat(mask, "b j -> (b h) () j", h=h)
- # sim.masked_fill_(~mask, max_neg_val)
-
attn = sim.softmax(dim=-1)
out = einsum("b i j, b j d -> b i d", attn, v)
out = rearrange(out, "(b h) n d -> b n (h d)", h=h)