diff options
Diffstat (limited to 'text_recognizer/networks/transformer/nystromer/attention.py')
-rw-r--r-- | text_recognizer/networks/transformer/nystromer/attention.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/text_recognizer/networks/transformer/nystromer/attention.py b/text_recognizer/networks/transformer/nystromer/attention.py index c2871fb..5ab19cf 100644 --- a/text_recognizer/networks/transformer/nystromer/attention.py +++ b/text_recognizer/networks/transformer/nystromer/attention.py @@ -157,14 +157,14 @@ class NystromAttention(nn.Module): self, x: Tensor, mask: Optional[Tensor] = None, return_attn: bool = False ) -> Union[Tensor, Tuple[Tensor, Tensor]]: """Compute the Nystrom attention.""" - _, n, _, h, m = x.shape, self.num_heads + _, n, _, h, m = *x.shape, self.num_heads, self.num_landmarks if n % m != 0: x, mask = self._pad_sequence(x, mask, n, m) q, k, v = self.qkv_fn(x).chunk(3, dim=-1) q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) - q *= self.scale + q = q * self.scale out, attn = self._nystrom_attention(q, k, v, mask, n, m, return_attn) |