summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/vqvae/attention.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/vqvae/attention.py')
-rw-r--r--text_recognizer/networks/vqvae/attention.py74
1 files changed, 74 insertions, 0 deletions
diff --git a/text_recognizer/networks/vqvae/attention.py b/text_recognizer/networks/vqvae/attention.py
new file mode 100644
index 0000000..5a6b3ce
--- /dev/null
+++ b/text_recognizer/networks/vqvae/attention.py
@@ -0,0 +1,74 @@
+"""Convolutional attention block."""
+import attr
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+
+from text_recognizer.networks.vqvae.norm import Normalize
+
+
@attr.s(eq=False)
class Attention(nn.Module):
    """Convolutional self-attention over the spatial positions of a feature map.

    The input is normalized, projected to queries, keys and values with 1x1
    convolutions, mixed across all ``H * W`` positions with scaled dot-product
    attention, projected once more, and added back to the input (residual).

    ``eq=False`` is required: by default attrs generates ``__eq__`` and sets
    ``__hash__`` to ``None``, which makes the module unhashable and breaks
    ``nn.Module`` machinery that stores modules in sets (``named_modules``,
    ``parameters``, ``to`` ...).

    Args:
        in_channels: Number of channels of the input feature map. Every
            convolution maps ``in_channels`` to ``in_channels``, so the output
            has the same shape as the input.
    """

    in_channels: int = attr.ib()
    # Sub-modules are created in ``__attrs_post_init__`` and are therefore
    # excluded from the generated ``__init__`` signature.
    q: nn.Conv2d = attr.ib(init=False)
    k: nn.Conv2d = attr.ib(init=False)
    v: nn.Conv2d = attr.ib(init=False)
    proj: nn.Conv2d = attr.ib(init=False)
    norm: Normalize = attr.ib(init=False)

    def __attrs_post_init__(self) -> None:
        """Initializes ``nn.Module`` and builds the sub-modules."""
        # The attrs-generated ``__init__`` does not call ``nn.Module.__init__``;
        # it has to run before any sub-module can be assigned.
        super().__init__()
        # Creation order is kept as q, k, v, norm, proj so that parameter
        # registration order (and thus state-dict ordering) is unchanged.
        self.q = self._pointwise_conv()
        self.k = self._pointwise_conv()
        self.v = self._pointwise_conv()
        self.norm = Normalize(num_channels=self.in_channels)
        self.proj = self._pointwise_conv()

    def _pointwise_conv(self) -> nn.Conv2d:
        """Returns a channel-preserving 1x1 convolution."""
        return nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.in_channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Applies attention to feature maps.

        Args:
            x: Feature map of shape ``[B, C, H, W]`` with ``C == in_channels``.

        Returns:
            Tensor of shape ``[B, C, H, W]``: the attended features plus the
            unnormalized input.
        """
        residual = x
        x = self.norm(x)
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)

        # Attention weights between every pair of spatial positions.
        B, C, H, W = q.shape
        q = q.reshape(B, C, H * W).permute(0, 2, 1)  # [B, HW, C]
        k = k.reshape(B, C, H * W)  # [B, C, HW]
        energy = torch.bmm(q, k) * (C ** -0.5)  # [B, HW(query), HW(key)]
        attention = F.softmax(energy, dim=2)  # normalized over keys

        # Weighted sum of values for every query position:
        #   out[b, c, i] = sum_j v[b, c, j] * attention[b, i, j]
        # ``v`` stays channel-first and the attention matrix is transposed so
        # the batched matmul yields [B, C, HW], which reshapes directly back
        # to the input layout.
        v = v.reshape(B, C, H * W)  # [B, C, HW]
        attention = attention.permute(0, 2, 1)  # [B, HW(key), HW(query)]
        out = torch.bmm(v, attention)  # [B, C, HW]
        out = out.reshape(B, C, H, W)
        out = self.proj(out)
        return out + residual