From c9c60678673e19ad3367339eb8e7a093e5a98474 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Sun, 9 May 2021 22:46:09 +0200
Subject: Reformatting of positional encodings and ViT working

---
 text_recognizer/networks/transformer/nystromer/attention.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

(limited to 'text_recognizer/networks/transformer/nystromer')

diff --git a/text_recognizer/networks/transformer/nystromer/attention.py b/text_recognizer/networks/transformer/nystromer/attention.py
index c2871fb..5ab19cf 100644
--- a/text_recognizer/networks/transformer/nystromer/attention.py
+++ b/text_recognizer/networks/transformer/nystromer/attention.py
@@ -157,14 +157,14 @@ class NystromAttention(nn.Module):
         self, x: Tensor, mask: Optional[Tensor] = None, return_attn: bool = False
     ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
         """Compute the Nystrom attention."""
-        _, n, _, h, m = x.shape, self.num_heads
+        _, n, _, h, m = *x.shape, self.num_heads, self.num_landmarks

         if n % m != 0:
             x, mask = self._pad_sequence(x, mask, n, m)

         q, k, v = self.qkv_fn(x).chunk(3, dim=-1)
         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
-        q *= self.scale
+        q = q * self.scale

         out, attn = self._nystrom_attention(q, k, v, mask, n, m, return_attn)
--
cgit v1.2.3-70-g09d2
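
For context, a minimal standalone sketch of why the first hunk is a fix, not just a reformat. The sizes below are hypothetical stand-ins for the module's `num_heads` and `num_landmarks` attributes; the old right-hand side is a 2-tuple, so unpacking it into five targets fails at runtime, while the starred form splats the three shape dimensions into the tuple:

    import torch

    # Hypothetical sizes standing in for the module's attributes.
    x = torch.randn(2, 64, 128)  # (batch, seq_len, dim)
    num_heads, num_landmarks = 8, 16

    # Old line: the right-hand side is a 2-tuple (a torch.Size and an int),
    # so unpacking into five targets raises
    # "ValueError: not enough values to unpack (expected 5, got 2)".
    # _, n, _, h, m = x.shape, num_heads

    # Fixed line: the star expands the three shape dims into the tuple,
    # yielding exactly five values for the five targets.
    _, n, _, h, m = *x.shape, num_heads, num_landmarks
    assert (n, h, m) == (64, 8, 16)

The second hunk swaps the in-place `q *= self.scale` for an out-of-place multiply. A likely motivation (not stated in the commit message) is that in-place operations on tensors that autograd still needs for the backward pass can raise a RuntimeError during training, so the out-of-place form is the safer default.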