summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/nystromer
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/transformer/nystromer')
-rw-r--r-- text_recognizer/networks/transformer/nystromer/attention.py | 4
1 file changed, 2 insertions, 2 deletions
diff --git a/text_recognizer/networks/transformer/nystromer/attention.py b/text_recognizer/networks/transformer/nystromer/attention.py
index c2871fb..5ab19cf 100644
--- a/text_recognizer/networks/transformer/nystromer/attention.py
+++ b/text_recognizer/networks/transformer/nystromer/attention.py
@@ -157,14 +157,14 @@ class NystromAttention(nn.Module):
self, x: Tensor, mask: Optional[Tensor] = None, return_attn: bool = False
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
"""Compute the Nystrom attention."""
- _, n, _, h, m = x.shape, self.num_heads
+ _, n, _, h, m = *x.shape, self.num_heads, self.num_landmarks
if n % m != 0:
x, mask = self._pad_sequence(x, mask, n, m)
q, k, v = self.qkv_fn(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
- q *= self.scale
+ q = q * self.scale
out, attn = self._nystrom_attention(q, k, v, mask, n, m, return_attn)