author     Gustaf Rydholm <gustaf.rydholm@gmail.com>   2021-05-09 22:46:09 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>   2021-05-09 22:46:09 +0200
commit     c9c60678673e19ad3367339eb8e7a093e5a98474 (patch)
tree       b787a7fbb535c2ee44f935720d75034cc24ffd30 /text_recognizer/networks/transformer/nystromer
parent     a2a3133ed5da283888efbdb9924d0e3733c274c8 (diff)
Reformatting of positional encodings and ViT working
Diffstat (limited to 'text_recognizer/networks/transformer/nystromer')
-rw-r--r--   text_recognizer/networks/transformer/nystromer/attention.py   4
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/text_recognizer/networks/transformer/nystromer/attention.py b/text_recognizer/networks/transformer/nystromer/attention.py
index c2871fb..5ab19cf 100644
--- a/text_recognizer/networks/transformer/nystromer/attention.py
+++ b/text_recognizer/networks/transformer/nystromer/attention.py
@@ -157,14 +157,14 @@ class NystromAttention(nn.Module):
         self, x: Tensor, mask: Optional[Tensor] = None, return_attn: bool = False
     ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
         """Compute the Nystrom attention."""
-        _, n, _, h, m = x.shape, self.num_heads
+        _, n, _, h, m = *x.shape, self.num_heads, self.num_landmarks
         if n % m != 0:
             x, mask = self._pad_sequence(x, mask, n, m)
 
         q, k, v = self.qkv_fn(x).chunk(3, dim=-1)
         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
 
-        q *= self.scale
+        q = q * self.scale
 
        out, attn = self._nystrom_attention(q, k, v, mask, n, m, return_attn)
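Both hunks are one-line fixes for real runtime bugs. The sketch below (not part of the commit) illustrates why each change is needed; the tensor sizes and the scale value are illustrative assumptions, while the names num_heads, num_landmarks, and scale mirror the attributes used in NystromAttention.

import torch

x = torch.randn(2, 256, 64)          # (batch, seq_len, dim) -- sizes assumed
num_heads, num_landmarks = 8, 64

# Hunk 1: the old line bound the entire shape tuple to the first target,
# leaving only two values for five targets, which raises
# "ValueError: not enough values to unpack (expected 5, got 2)".
# Star-unpacking splices the three shape dims into the tuple first.
_, n, _, h, m = *x.shape, num_heads, num_landmarks
assert (n, h, m) == (256, 8, 64)

# Hunk 2: `q *= scale` mutates q in place. On a leaf tensor that requires
# grad this raises "a leaf Variable that requires grad is being used in an
# in-place operation", and in-place mutation of an intermediate whose value
# autograd saved for backward fails later, at backward time. The functional
# form allocates a new tensor and avoids both failure modes.
head_dim = 64 // num_heads
q = torch.randn(2, num_heads, 256, head_dim, requires_grad=True)
scale = head_dim ** -0.5             # assumed head_dim ** -0.5 scaling
q = q * scale                        # safe; `q *= scale` would raise here

Note that with the old unpacking, forward() could never get past its first statement, so the in-place multiply was unreachable; together the two changes make the forward pass executable and autograd-safe.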