summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/nystromer/attention.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-05-09 22:46:09 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-05-09 22:46:09 +0200
commitc9c60678673e19ad3367339eb8e7a093e5a98474 (patch)
treeb787a7fbb535c2ee44f935720d75034cc24ffd30 /text_recognizer/networks/transformer/nystromer/attention.py
parenta2a3133ed5da283888efbdb9924d0e3733c274c8 (diff)
Reformatting of positional encodings and ViT working
Diffstat (limited to 'text_recognizer/networks/transformer/nystromer/attention.py')
-rw-r--r--text_recognizer/networks/transformer/nystromer/attention.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/text_recognizer/networks/transformer/nystromer/attention.py b/text_recognizer/networks/transformer/nystromer/attention.py
index c2871fb..5ab19cf 100644
--- a/text_recognizer/networks/transformer/nystromer/attention.py
+++ b/text_recognizer/networks/transformer/nystromer/attention.py
@@ -157,14 +157,14 @@ class NystromAttention(nn.Module):
self, x: Tensor, mask: Optional[Tensor] = None, return_attn: bool = False
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
"""Compute the Nystrom attention."""
- _, n, _, h, m = x.shape, self.num_heads
+ _, n, _, h, m = *x.shape, self.num_heads, self.num_landmarks
if n % m != 0:
x, mask = self._pad_sequence(x, mask, n, m)
q, k, v = self.qkv_fn(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
- q *= self.scale
+ q = q * self.scale
out, attn = self._nystrom_attention(q, k, v, mask, n, m, return_attn)