summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/vqvae
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/vqvae')
-rw-r--r--text_recognizer/networks/vqvae/attention.py74
-rw-r--r--text_recognizer/networks/vqvae/decoder.py4
-rw-r--r--text_recognizer/networks/vqvae/encoder.py4
-rw-r--r--text_recognizer/networks/vqvae/norm.py20
-rw-r--r--text_recognizer/networks/vqvae/quantizer.py2
-rw-r--r--text_recognizer/networks/vqvae/resize.py19
6 files changed, 116 insertions, 7 deletions
diff --git a/text_recognizer/networks/vqvae/attention.py b/text_recognizer/networks/vqvae/attention.py
new file mode 100644
index 0000000..5a6b3ce
--- /dev/null
+++ b/text_recognizer/networks/vqvae/attention.py
@@ -0,0 +1,74 @@
+"""Convolutional attention block."""
+import attr
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+
+from text_recognizer.networks.vqvae.norm import Normalize
+
+
+@attr.s
+class Attention(nn.Module):
+ """Convolutional attention."""
+
+ in_channels: int = attr.ib()
+ q: nn.Conv2d = attr.ib(init=False)
+ k: nn.Conv2d = attr.ib(init=False)
+ v: nn.Conv2d = attr.ib(init=False)
+ proj: nn.Conv2d = attr.ib(init=False)
+ norm: Normalize = attr.ib(init=False)
+
+ def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
+ super().__init__()
+ self.q = nn.Conv2d(
+ in_channels=self.in_channels,
+ out_channels=self.in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+ self.k = nn.Conv2d(
+ in_channels=self.in_channels,
+ out_channels=self.in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+ self.v = nn.Conv2d(
+ in_channels=self.in_channels,
+ out_channels=self.in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+ self.norm = Normalize(num_channels=self.in_channels)
+ self.proj = nn.Conv2d(
+ in_channels=self.in_channels,
+ out_channels=self.in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Applies attention to feature maps."""
+ residual = x
+ x = self.norm(x)
+ q = self.q(x)
+ k = self.k(x)
+ v = self.v(x)
+
+ # Attention
+ B, C, H, W = q.shape
+ q = q.reshape(B, C, H * W).permute(0, 2, 1) # [B, HW, C]
+ k = k.reshape(B, C, H * W) # [B, C, HW]
+ energy = torch.bmm(q, k) * (C ** -0.5)
+ attention = F.softmax(energy, dim=2)
+
+ # Compute attention to which values
+ v = v.reshape(B, C, H * W).permute(0, 2, 1) # [B, HW, C]
+ out = torch.bmm(v, attention)
+ out = out.reshape(B, C, H, W)
+ out = self.proj(out)
+ return out + residual
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
index 3f59f0d..fcf768b 100644
--- a/text_recognizer/networks/vqvae/decoder.py
+++ b/text_recognizer/networks/vqvae/decoder.py
@@ -19,11 +19,9 @@ class Decoder(nn.Module):
activation: str = attr.ib()
decoder: nn.Sequential = attr.ib(init=False)
- def __attrs_pre_init__(self) -> None:
- super().__init__()
-
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
+ super().__init__()
self.decoder = self._build_decompression_block()
def _build_decompression_block(self,) -> nn.Sequential:
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
index e480545..f086c6b 100644
--- a/text_recognizer/networks/vqvae/encoder.py
+++ b/text_recognizer/networks/vqvae/encoder.py
@@ -21,11 +21,9 @@ class Encoder(nn.Module):
activation: str = attr.ib()
encoder: nn.Sequential = attr.ib(init=False)
- def __attrs_pre_init__(self) -> None:
- super().__init__()
-
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
+ super().__init__()
self.encoder = self._build_compression_block()
def _build_compression_block(self) -> nn.Sequential:
diff --git a/text_recognizer/networks/vqvae/norm.py b/text_recognizer/networks/vqvae/norm.py
new file mode 100644
index 0000000..df66efc
--- /dev/null
+++ b/text_recognizer/networks/vqvae/norm.py
@@ -0,0 +1,20 @@
+"""Normalizer block."""
+import attr
+from torch import nn, Tensor
+
+
+@attr.s
+class Normalize(nn.Module):
+ num_channels: int = attr.ib()
+ norm: nn.GroupNorm = attr.ib(init=False)
+
+ def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
+ super().__init__()
+ self.norm = nn.GroupNorm(
+ num_groups=32, num_channels=self.num_channels, eps=1.0e-6, affine=True
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Applies group normalization."""
+ return self.norm(x)
diff --git a/text_recognizer/networks/vqvae/quantizer.py b/text_recognizer/networks/vqvae/quantizer.py
index 1b59e78..a4f11f0 100644
--- a/text_recognizer/networks/vqvae/quantizer.py
+++ b/text_recognizer/networks/vqvae/quantizer.py
@@ -7,7 +7,7 @@ from einops import rearrange
import torch
from torch import nn
from torch import Tensor
-from torch.nn import functional as F
+import torch.nn.functional as F
class EmbeddingEMA(nn.Module):
diff --git a/text_recognizer/networks/vqvae/resize.py b/text_recognizer/networks/vqvae/resize.py
new file mode 100644
index 0000000..769d089
--- /dev/null
+++ b/text_recognizer/networks/vqvae/resize.py
@@ -0,0 +1,19 @@
+"""Up and down-sample with linear interpolation."""
+from torch import nn, Tensor
+import torch.nn.functional as F
+
+
+class Upsample(nn.Module):
+ """Upsamples by a factor 2."""
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Applies upsampling."""
+ return F.interpolate(x, scale_factor=2, mode="nearest")
+
+
+class Downsample(nn.Module):
+ """Downsampling by a factor 2."""
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Applies downsampling."""
+ return F.avg_pool2d(x, kernel_size=2, stride=2)