diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-05 19:26:58 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-05 19:26:58 +0100 |
commit | 865999f42a83923bf9f72d0c5b7e0f9a7437c054 (patch) | |
tree | 5231578017c08158c54202cc1f502b8f118c78a6 /text_recognizer/networks | |
parent | e5e776cb7ce3486d1a9e16f6ae328f55fd20f02b (diff) |
Remove conv attention
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/vqvae/attention.py | 75 |
1 files changed, 0 insertions, 75 deletions
diff --git a/text_recognizer/networks/vqvae/attention.py b/text_recognizer/networks/vqvae/attention.py deleted file mode 100644 index 78a2cc9..0000000 --- a/text_recognizer/networks/vqvae/attention.py +++ /dev/null @@ -1,75 +0,0 @@ -"""Convolutional attention block.""" -import attr -import torch -from torch import nn, Tensor -import torch.nn.functional as F - -from text_recognizer.networks.vqvae.norm import Normalize - - -@attr.s(eq=False) -class Attention(nn.Module): - """Convolutional attention.""" - - in_channels: int = attr.ib() - q: nn.Conv2d = attr.ib(init=False) - k: nn.Conv2d = attr.ib(init=False) - v: nn.Conv2d = attr.ib(init=False) - proj: nn.Conv2d = attr.ib(init=False) - norm: Normalize = attr.ib(init=False) - - def __attrs_post_init__(self) -> None: - """Post init configuration.""" - super().__init__() - self.q = nn.Conv2d( - in_channels=self.in_channels, - out_channels=self.in_channels, - kernel_size=1, - stride=1, - padding=0, - ) - self.k = nn.Conv2d( - in_channels=self.in_channels, - out_channels=self.in_channels, - kernel_size=1, - stride=1, - padding=0, - ) - self.v = nn.Conv2d( - in_channels=self.in_channels, - out_channels=self.in_channels, - kernel_size=1, - stride=1, - padding=0, - ) - self.norm = Normalize(num_channels=self.in_channels) - self.proj = nn.Conv2d( - in_channels=self.in_channels, - out_channels=self.in_channels, - kernel_size=1, - stride=1, - padding=0, - ) - - def forward(self, x: Tensor) -> Tensor: - """Applies attention to feature maps.""" - residual = x - x = self.norm(x) - q = self.q(x) - k = self.k(x) - v = self.v(x) - - # Attention - B, C, H, W = q.shape - q = q.reshape(B, C, H * W).permute(0, 2, 1) # [B, HW, C] - k = k.reshape(B, C, H * W) # [B, C, HW] - energy = torch.bmm(q, k) * (int(C) ** -0.5) - attention = F.softmax(energy, dim=2) - - # Compute attention to which values - v = v.reshape(B, C, H * W) - attention = attention.permute(0, 2, 1) # [B, HW, HW] - out = torch.bmm(v, attention) - out = out.reshape(B, C, H, W) - out = self.proj(out) - return out + residual |