diff options
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/vqvae/attention.py | 75 |
1 files changed, 0 insertions, 75 deletions
diff --git a/text_recognizer/networks/vqvae/attention.py b/text_recognizer/networks/vqvae/attention.py deleted file mode 100644 index 78a2cc9..0000000 --- a/text_recognizer/networks/vqvae/attention.py +++ /dev/null @@ -1,75 +0,0 @@ -"""Convolutional attention block.""" -import attr -import torch -from torch import nn, Tensor -import torch.nn.functional as F - -from text_recognizer.networks.vqvae.norm import Normalize - - -@attr.s(eq=False) -class Attention(nn.Module): - """Convolutional attention.""" - - in_channels: int = attr.ib() - q: nn.Conv2d = attr.ib(init=False) - k: nn.Conv2d = attr.ib(init=False) - v: nn.Conv2d = attr.ib(init=False) - proj: nn.Conv2d = attr.ib(init=False) - norm: Normalize = attr.ib(init=False) - - def __attrs_post_init__(self) -> None: - """Post init configuration.""" - super().__init__() - self.q = nn.Conv2d( - in_channels=self.in_channels, - out_channels=self.in_channels, - kernel_size=1, - stride=1, - padding=0, - ) - self.k = nn.Conv2d( - in_channels=self.in_channels, - out_channels=self.in_channels, - kernel_size=1, - stride=1, - padding=0, - ) - self.v = nn.Conv2d( - in_channels=self.in_channels, - out_channels=self.in_channels, - kernel_size=1, - stride=1, - padding=0, - ) - self.norm = Normalize(num_channels=self.in_channels) - self.proj = nn.Conv2d( - in_channels=self.in_channels, - out_channels=self.in_channels, - kernel_size=1, - stride=1, - padding=0, - ) - - def forward(self, x: Tensor) -> Tensor: - """Applies attention to feature maps.""" - residual = x - x = self.norm(x) - q = self.q(x) - k = self.k(x) - v = self.v(x) - - # Attention - B, C, H, W = q.shape - q = q.reshape(B, C, H * W).permute(0, 2, 1) # [B, HW, C] - k = k.reshape(B, C, H * W) # [B, C, HW] - energy = torch.bmm(q, k) * (int(C) ** -0.5) - attention = F.softmax(energy, dim=2) - - # Compute attention to which values - v = v.reshape(B, C, H * W) - attention = attention.permute(0, 2, 1) # [B, HW, HW] - out = torch.bmm(v, attention) - out = out.reshape(B, C, H, W) - out = self.proj(out) - return out + residual |