From 4b112d4164d4de103997c5ddcadb802ab6440c8d Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 5 Sep 2022 00:04:55 +0200 Subject: Remove mask from perceiver attention --- text_recognizer/networks/perceiver/attention.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) (limited to 'text_recognizer/networks') diff --git a/text_recognizer/networks/perceiver/attention.py b/text_recognizer/networks/perceiver/attention.py index 0ee51b1..19e3e17 100644 --- a/text_recognizer/networks/perceiver/attention.py +++ b/text_recognizer/networks/perceiver/attention.py @@ -25,9 +25,7 @@ class Attention(nn.Module): self.to_kv = nn.Linear(context_dim, 2 * inner_dim, bias=False) self.to_out = nn.Linear(inner_dim, query_dim, bias=False) - def forward( - self, x: Tensor, context: Optional[Tensor] = None, mask=Optional[Tensor] - ) -> Tensor: + def forward(self, x: Tensor, context: Optional[Tensor] = None) -> Tensor: h = self.heads q = self.to_q(x) context = context if context is not None else x @@ -36,12 +34,6 @@ class Attention(nn.Module): q, v, k = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) sim = einsum("b i d, b j d -> b i j", q, k) * self.scale - # if mask is not None: - # mask = rearrange(mask, "b ... -> b (...)") - # max_neg_val = -torch.finfo(sim.dtype).max - # mask = repeat(mask, "b j -> (b h) () j", h=h) - # sim.masked_fill_(~mask, max_neg_val) - attn = sim.softmax(dim=-1) out = einsum("b i j, b j d -> b i d", attn, v) out = rearrange(out, "(b h) n d -> b n (h d)", h=h) -- cgit v1.2.3-70-g09d2