author     Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-08-06 02:42:45 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-08-06 02:42:45 +0200
commit     3ab82ad36bce6fa698a13a029a0694b75a5947b7 (patch)
tree       136f71a62d60e3ccf01e1f95d64bb4d9f9c9befe /text_recognizer/networks/vqvae/attention.py
parent     1bccf71cf4eec335001b50a8fbc0c991d0e6d13a (diff)
Split VQVAE into encoder/decoder; fix bug in wandb artifact code uploading
Diffstat (limited to 'text_recognizer/networks/vqvae/attention.py')
-rw-r--r--  text_recognizer/networks/vqvae/attention.py | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/text_recognizer/networks/vqvae/attention.py b/text_recognizer/networks/vqvae/attention.py
index 5a6b3ce..78a2cc9 100644
--- a/text_recognizer/networks/vqvae/attention.py
+++ b/text_recognizer/networks/vqvae/attention.py
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 from text_recognizer.networks.vqvae.norm import Normalize
 
 
-@attr.s
+@attr.s(eq=False)
 class Attention(nn.Module):
     """Convolutional attention."""
 
@@ -63,11 +63,12 @@ class Attention(nn.Module):
         B, C, H, W = q.shape
         q = q.reshape(B, C, H * W).permute(0, 2, 1)  # [B, HW, C]
         k = k.reshape(B, C, H * W)  # [B, C, HW]
-        energy = torch.bmm(q, k) * (C ** -0.5)
+        energy = torch.bmm(q, k) * (int(C) ** -0.5)
         attention = F.softmax(energy, dim=2)
 
         # Compute attention to which values
-        v = v.reshape(B, C, H * W).permute(0, 2, 1)  # [B, HW, C]
+        v = v.reshape(B, C, H * W)
+        attention = attention.permute(0, 2, 1)  # [B, HW, HW]
         out = torch.bmm(v, attention)
         out = out.reshape(B, C, H, W)
         out = self.proj(out)
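For context, a minimal self-contained sketch of the attention math after this patch. This is not the repository's full `Attention` module: the free function `conv_attention` and the test tensors are invented here for illustration. The patched version keeps `v` in `[B, C, HW]` layout and transposes the attention map instead of permuting `v`, which computes the same weighted sum, out[b, c, i] = sum_j attention[b, i, j] * v[b, c, j].

# Hedged sketch, assuming the patched forward-pass math; `conv_attention`
# is a hypothetical helper, not a function from this repo.
import torch
import torch.nn.functional as F


def conv_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Scaled dot-product attention over [B, C, H, W] feature maps."""
    B, C, H, W = q.shape
    q = q.reshape(B, C, H * W).permute(0, 2, 1)  # [B, HW, C]
    k = k.reshape(B, C, H * W)                   # [B, C, HW]
    energy = torch.bmm(q, k) * (int(C) ** -0.5)  # [B, HW, HW] scaled scores
    attention = F.softmax(energy, dim=2)         # each row sums to 1 over keys

    # As patched: keep v as [B, C, HW] and transpose the attention map,
    # so out[b, c, i] = sum_j attention[b, i, j] * v[b, c, j].
    v = v.reshape(B, C, H * W)                   # [B, C, HW]
    attention = attention.permute(0, 2, 1)       # [B, HW, HW], transposed
    out = torch.bmm(v, attention)                # [B, C, HW]
    return out.reshape(B, C, H, W)


x = torch.randn(2, 8, 4, 4)
assert conv_attention(x, x, x).shape == (2, 8, 4, 4)

As for the decorator change: with attrs' default `eq=True`, the generated `__eq__` sets `__hash__` to `None`, making instances unhashable; `@attr.s(eq=False)` preserves identity-based equality and hashing, which `nn.Module` relies on for its internal module and parameter bookkeeping.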