From 865999f42a83923bf9f72d0c5b7e0f9a7437c054 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Fri, 5 Nov 2021 19:26:58 +0100
Subject: Remove conv attention

---
 text_recognizer/networks/vqvae/attention.py | 75 -----------------------------
 1 file changed, 75 deletions(-)
 delete mode 100644 text_recognizer/networks/vqvae/attention.py

diff --git a/text_recognizer/networks/vqvae/attention.py b/text_recognizer/networks/vqvae/attention.py
deleted file mode 100644
index 78a2cc9..0000000
--- a/text_recognizer/networks/vqvae/attention.py
+++ /dev/null
@@ -1,75 +0,0 @@
-"""Convolutional attention block."""
-import attr
-import torch
-from torch import nn, Tensor
-import torch.nn.functional as F
-
-from text_recognizer.networks.vqvae.norm import Normalize
-
-
-@attr.s(eq=False)
-class Attention(nn.Module):
-    """Convolutional attention."""
-
-    in_channels: int = attr.ib()
-    q: nn.Conv2d = attr.ib(init=False)
-    k: nn.Conv2d = attr.ib(init=False)
-    v: nn.Conv2d = attr.ib(init=False)
-    proj: nn.Conv2d = attr.ib(init=False)
-    norm: Normalize = attr.ib(init=False)
-
-    def __attrs_post_init__(self) -> None:
-        """Post init configuration."""
-        super().__init__()
-        self.q = nn.Conv2d(
-            in_channels=self.in_channels,
-            out_channels=self.in_channels,
-            kernel_size=1,
-            stride=1,
-            padding=0,
-        )
-        self.k = nn.Conv2d(
-            in_channels=self.in_channels,
-            out_channels=self.in_channels,
-            kernel_size=1,
-            stride=1,
-            padding=0,
-        )
-        self.v = nn.Conv2d(
-            in_channels=self.in_channels,
-            out_channels=self.in_channels,
-            kernel_size=1,
-            stride=1,
-            padding=0,
-        )
-        self.norm = Normalize(num_channels=self.in_channels)
-        self.proj = nn.Conv2d(
-            in_channels=self.in_channels,
-            out_channels=self.in_channels,
-            kernel_size=1,
-            stride=1,
-            padding=0,
-        )
-
-    def forward(self, x: Tensor) -> Tensor:
-        """Applies attention to feature maps."""
-        residual = x
-        x = self.norm(x)
-        q = self.q(x)
-        k = self.k(x)
-        v = self.v(x)
-
-        # Attention
-        B, C, H, W = q.shape
-        q = q.reshape(B, C, H * W).permute(0, 2, 1)  # [B, HW, C]
-        k = k.reshape(B, C, H * W)  # [B, C, HW]
-        energy = torch.bmm(q, k) * (int(C) ** -0.5)
-        attention = F.softmax(energy, dim=2)
-
-        # Apply the attention weights to the values
-        v = v.reshape(B, C, H * W)
-        attention = attention.permute(0, 2, 1)  # [B, HW, HW]
-        out = torch.bmm(v, attention)
-        out = out.reshape(B, C, H, W)
-        out = self.proj(out)
-        return out + residual
-- 
cgit v1.2.3-70-g09d2