diff options
Diffstat (limited to 'text_recognizer/networks/vqvae/attention.py')
-rw-r--r-- | text_recognizer/networks/vqvae/attention.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/text_recognizer/networks/vqvae/attention.py b/text_recognizer/networks/vqvae/attention.py index 5a6b3ce..78a2cc9 100644 --- a/text_recognizer/networks/vqvae/attention.py +++ b/text_recognizer/networks/vqvae/attention.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from text_recognizer.networks.vqvae.norm import Normalize -@attr.s +@attr.s(eq=False) class Attention(nn.Module): """Convolutional attention.""" @@ -63,11 +63,12 @@ class Attention(nn.Module): B, C, H, W = q.shape q = q.reshape(B, C, H * W).permute(0, 2, 1) # [B, HW, C] k = k.reshape(B, C, H * W) # [B, C, HW] - energy = torch.bmm(q, k) * (C ** -0.5) + energy = torch.bmm(q, k) * (int(C) ** -0.5) attention = F.softmax(energy, dim=2) # Compute attention to which values - v = v.reshape(B, C, H * W).permute(0, 2, 1) # [B, HW, C] + v = v.reshape(B, C, H * W) + attention = attention.permute(0, 2, 1) # [B, HW, HW] out = torch.bmm(v, attention) out = out.reshape(B, C, H, W) out = self.proj(out) |