Diffstat (limited to 'text_recognizer/networks/vqvae/attention.py')
 text_recognizer/networks/vqvae/attention.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/text_recognizer/networks/vqvae/attention.py b/text_recognizer/networks/vqvae/attention.py
index 5a6b3ce..78a2cc9 100644
--- a/text_recognizer/networks/vqvae/attention.py
+++ b/text_recognizer/networks/vqvae/attention.py
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 from text_recognizer.networks.vqvae.norm import Normalize


-@attr.s
+@attr.s(eq=False)
 class Attention(nn.Module):
     """Convolutional attention."""
@@ -63,11 +63,12 @@ class Attention(nn.Module):
         B, C, H, W = q.shape
         q = q.reshape(B, C, H * W).permute(0, 2, 1)  # [B, HW, C]
         k = k.reshape(B, C, H * W)  # [B, C, HW]
-        energy = torch.bmm(q, k) * (C ** -0.5)
+        energy = torch.bmm(q, k) * (int(C) ** -0.5)
         attention = F.softmax(energy, dim=2)

         # Compute attention to which values
-        v = v.reshape(B, C, H * W).permute(0, 2, 1)  # [B, HW, C]
+        v = v.reshape(B, C, H * W)
+        attention = attention.permute(0, 2, 1)  # [B, HW, HW]
         out = torch.bmm(v, attention)
         out = out.reshape(B, C, H, W)
         out = self.proj(out)
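
The second hunk is a shape fix, not a cosmetic one: with v reshaped to [B, HW, C], torch.bmm(v, attention) contracts v's last dim (C) against attention's middle dim (HW), so the old code only ran when C happened to equal H * W. Keeping v as [B, C, HW] and transposing attention instead makes every output position a softmax-weighted sum of the value vectors. The int(C) cast presumably keeps the scale a plain Python float when the shape element arrives as a traced tensor rather than an int. A standalone sketch of the fixed math, with arbitrary sizes chosen so that C != H * W:

    import torch
    import torch.nn.functional as F

    B, C, H, W = 2, 16, 8, 8  # C != H * W, which crashed the old code
    q, k, v = (torch.randn(B, C, H, W) for _ in range(3))

    q = q.reshape(B, C, H * W).permute(0, 2, 1)  # [B, HW, C]
    k = k.reshape(B, C, H * W)  # [B, C, HW]
    energy = torch.bmm(q, k) * (int(C) ** -0.5)  # [B, HW, HW]
    attention = F.softmax(energy, dim=2)  # rows sum to 1 over source positions

    v = v.reshape(B, C, H * W)  # [B, C, HW]
    attention = attention.permute(0, 2, 1)  # [B, HW, HW], transposed
    out = torch.bmm(v, attention)  # [B, C, HW]
    out = out.reshape(B, C, H, W)
    print(out.shape)  # torch.Size([2, 16, 8, 8])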