Diffstat (limited to 'text_recognizer/networks/perceiver/attention.py')
-rw-r--r-- text_recognizer/networks/perceiver/attention.py | 10
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/text_recognizer/networks/perceiver/attention.py b/text_recognizer/networks/perceiver/attention.py
index 66aeaa8..0ee51b1 100644
--- a/text_recognizer/networks/perceiver/attention.py
+++ b/text_recognizer/networks/perceiver/attention.py
@@ -36,11 +36,11 @@ class Attention(nn.Module):
q, v, k = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
- if mask is not None:
- mask = rearrange(mask, "b ... -> b (...)")
- max_neg_val = -torch.finfo(sim.dtype).max
- mask = repeat(mask, "b j -> (b h) () j", h=h)
- sim.masked_fill_(~mask, max_neg_val)
+ # if mask is not None:
+ # mask = rearrange(mask, "b ... -> b (...)")
+ # max_neg_val = -torch.finfo(sim.dtype).max
+ # mask = repeat(mask, "b j -> (b h) () j", h=h)
+ # sim.masked_fill_(~mask, max_neg_val)
attn = sim.softmax(dim=-1)
out = einsum("b i j, b j d -> b i d", attn, v)
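For reference, below is a minimal standalone sketch of the masked attention path that this commit comments out. It is not the repository's Attention class: the tensor shapes, batch/head sizes, the dummy mask, and the final rearrange back to (b, n, h*d) are assumptions added only to make the example runnable, while the rearrange/einsum/masked_fill_ steps mirror the removed lines. It assumes torch and einops are installed.

    import torch
    from torch import einsum
    from einops import rearrange, repeat

    # Hypothetical sizes: batch b, query length i, key/value length j, heads h, head dim d.
    b, i, j, h, d = 2, 4, 6, 8, 16
    scale = d ** -0.5
    q = torch.randn(b, i, h * d)
    k = torch.randn(b, j, h * d)
    v = torch.randn(b, j, h * d)
    mask = torch.rand(b, j) > 0.1  # boolean padding mask, True = token is valid

    # Fold heads into the batch dimension, as in the diff's rearrange call.
    q, k, v = (rearrange(t, "b n (h d) -> (b h) n d", h=h) for t in (q, k, v))
    sim = einsum("b i d, b j d -> b i j", q, k) * scale

    # The lines the commit disables: flatten the mask, broadcast it over heads
    # and queries, then push masked positions to the most negative finite value
    # so softmax gives them (near) zero weight.
    mask = rearrange(mask, "b ... -> b (...)")
    max_neg_val = -torch.finfo(sim.dtype).max
    mask = repeat(mask, "b j -> (b h) () j", h=h)
    sim.masked_fill_(~mask, max_neg_val)

    attn = sim.softmax(dim=-1)
    out = einsum("b i j, b j d -> b i d", attn, v)
    out = rearrange(out, "(b h) n d -> b n (h d)", h=h)  # back to (b, n, h*d)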