From 3ab82ad36bce6fa698a13a029a0694b75a5947b7 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Fri, 6 Aug 2021 02:42:45 +0200
Subject: Split VQVAE into encoder/decoder modules; fix bug in wandb artifact upload code
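
The VQ-VAE is now assembled from injected encoder/decoder modules with a
single vector quantizer between them, replacing the hard-coded two-level
(top/bottom) hierarchy. A minimal usage sketch of the new API follows;
the class names come from this patch, while the channel sizes and image
shape below are illustrative, not the actual training configuration:

    import torch

    from text_recognizer.networks.vqvae.pixelcnn import Decoder, Encoder
    from text_recognizer.networks.vqvae.vqvae import VQVAE

    encoder = Encoder(
        in_channels=1, hidden_dim=32, channels_multipliers=[1, 2, 4],
        dropout_rate=0.1,
    )
    decoder = Decoder(
        out_channels=1, hidden_dim=32, channels_multipliers=[4, 2, 1],
        dropout_rate=0.1,
    )
    vqvae = VQVAE(
        encoder=encoder,
        decoder=decoder,
        hidden_dim=32 * 4,  # must match the encoder's output channels
        embedding_dim=64,
        num_embeddings=512,
    )
    x = torch.randn(1, 1, 64, 64)
    x_hat, vq_loss = vqvae(x)  # reconstruction and VQ loss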

---
 text_recognizer/data/emnist_mapping.py       |  14 ++-
 text_recognizer/data/transforms.py           |   3 +-
 text_recognizer/networks/conv_transformer.py |   7 +-
 text_recognizer/networks/vq_transformer.py   |  50 ++++----
 text_recognizer/networks/vqvae/__init__.py   |   1 -
 text_recognizer/networks/vqvae/attention.py  |   7 +-
 text_recognizer/networks/vqvae/decoder.py    |  83 ++++++++------
 text_recognizer/networks/vqvae/encoder.py    |  82 ++++++-------
 text_recognizer/networks/vqvae/norm.py       |   4 +-
 text_recognizer/networks/vqvae/pixelcnn.py   | 165 +++++++++++++++++++++++++++
 text_recognizer/networks/vqvae/quantizer.py  |  15 ++-
 text_recognizer/networks/vqvae/residual.py   |  53 +++++++--
 text_recognizer/networks/vqvae/resize.py     |   2 +-
 text_recognizer/networks/vqvae/vqvae.py      |  98 ++++------------
 14 files changed, 372 insertions(+), 212 deletions(-)
 create mode 100644 text_recognizer/networks/vqvae/pixelcnn.py

diff --git a/text_recognizer/data/emnist_mapping.py b/text_recognizer/data/emnist_mapping.py
index 925d214..3e91594 100644
--- a/text_recognizer/data/emnist_mapping.py
+++ b/text_recognizer/data/emnist_mapping.py
@@ -9,15 +9,23 @@ from text_recognizer.data.emnist import emnist_mapping
 
 
 class EmnistMapping(AbstractMapping):
-    def __init__(self, extra_symbols: Optional[Set[str]] = None) -> None:
+    def __init__(
+        self, extra_symbols: Optional[Set[str]] = None, lower: bool = True
+    ) -> None:
         self.extra_symbols = set(extra_symbols) if extra_symbols is not None else None
         self.mapping, self.inverse_mapping, self.input_size = emnist_mapping(
             self.extra_symbols
         )
+        if lower:
+            self._to_lower()
         super().__init__(self.input_size, self.mapping, self.inverse_mapping)
 
-    def __attrs_post_init__(self) -> None:
-        """Post init configuration."""
+    def _to_lower(self) -> None:
+        """Converts mapping to lowercase letters only."""
+        def _filter(x: int) -> int:
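+            # Assumes lowercase letters start at index 40 in the EMNIST
+            # mapping; shifting them down by 26 places each lowercase letter
+            # in the slot of its removed uppercase counterpart.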
+            if 40 <= x:
+                return x - 26
+            return x
+        self.inverse_mapping = {v: _filter(k) for k, v in enumerate(self.mapping)}
+        self.mapping = [c for c in self.mapping if not c.isupper()]
 
     def get_token(self, index: Union[int, Tensor]) -> str:
-        if (index := int(index)) <= len(self.mapping):
+        if (index := int(index)) < len(self.mapping):
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index 047496f..51f52de 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -1,10 +1,11 @@
 """Transforms for PyTorch datasets."""
 from pathlib import Path
-from typing import Optional, Union, Set
+from typing import Optional, Union, Type, Set
 
 import torch
 from torch import Tensor
 
+from text_recognizer.data.base_mapping import AbstractMapping
 from text_recognizer.data.word_piece_mapping import WordPieceMapping
 
 
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index f3ba49d..b1a101e 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -4,7 +4,6 @@ from typing import Tuple
 
 from torch import nn, Tensor
 
-from text_recognizer.networks.encoders.efficientnet import EfficientNet
 from text_recognizer.networks.transformer.layers import Decoder
 from text_recognizer.networks.transformer.positional_encodings import (
     PositionalEncoding,
@@ -18,15 +17,17 @@ class ConvTransformer(nn.Module):
     def __init__(
         self,
         input_dims: Tuple[int, int, int],
+        encoder_dim: int,
         hidden_dim: int,
         dropout_rate: float,
         num_classes: int,
         pad_index: Tensor,
-        encoder: EfficientNet,
+        encoder: nn.Module,
         decoder: Decoder,
     ) -> None:
         super().__init__()
         self.input_dims = input_dims
+        self.encoder_dim = encoder_dim
         self.hidden_dim = hidden_dim
         self.dropout_rate = dropout_rate
         self.num_classes = num_classes
@@ -38,7 +39,7 @@ class ConvTransformer(nn.Module):
         # positional encoding.
         self.latent_encoder = nn.Sequential(
             nn.Conv2d(
-                in_channels=self.encoder.out_channels,
+                in_channels=self.encoder_dim,
                 out_channels=self.hidden_dim,
                 kernel_size=1,
             ),
diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py
index a972565..0433863 100644
--- a/text_recognizer/networks/vq_transformer.py
+++ b/text_recognizer/networks/vq_transformer.py
@@ -1,16 +1,12 @@
 """Vector quantized encoder, transformer decoder."""
-import math
 from typing import Tuple
 
-from torch import nn, Tensor
+import torch
+from torch import Tensor
 
-from text_recognizer.networks.encoders.efficientnet import EfficientNet
+from text_recognizer.networks.vqvae.vqvae import VQVAE
 from text_recognizer.networks.conv_transformer import ConvTransformer
 from text_recognizer.networks.transformer.layers import Decoder
-from text_recognizer.networks.transformer.positional_encodings import (
-    PositionalEncoding,
-    PositionalEncoding2D,
-)
 
 
 class VqTransformer(ConvTransformer):
@@ -19,16 +15,18 @@ class VqTransformer(ConvTransformer):
     def __init__(
         self,
         input_dims: Tuple[int, int, int],
+        encoder_dim: int,
         hidden_dim: int,
         dropout_rate: float,
         num_classes: int,
         pad_index: Tensor,
-        encoder: EfficientNet,
+        encoder: VQVAE,
         decoder: Decoder,
+        pretrained_encoder_path: str,
     ) -> None:
-        # TODO: Load pretrained vqvae encoder.
         super().__init__(
             input_dims=input_dims,
+            encoder_dim=encoder_dim,
             hidden_dim=hidden_dim,
             dropout_rate=dropout_rate,
             num_classes=num_classes,
@@ -36,24 +34,19 @@ class VqTransformer(ConvTransformer):
             encoder=encoder,
             decoder=decoder,
         )
-        # Latent projector for down sampling number of filters and 2d
-        # positional encoding.
-        self.latent_encoder = nn.Sequential(
-            nn.Conv2d(
-                in_channels=self.encoder.out_channels,
-                out_channels=self.hidden_dim,
-                kernel_size=1,
-            ),
-            PositionalEncoding2D(
-                hidden_dim=self.hidden_dim,
-                max_h=self.input_dims[1],
-                max_w=self.input_dims[2],
-            ),
-            nn.Flatten(start_dim=2),
-        )
+        self.pretrained_encoder_path = pretrained_encoder_path
+
+        # For typing
+        self.encoder: VQVAE
+
+    def setup_encoder(self) -> None:
+        """Remove unecessary layers."""
+        # TODO: load pretrained vqvae
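+        # The decoder and post-codebook conv are only needed for
+        # reconstruction, not for recognition, so they are dropped.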
+        del self.encoder.decoder
+        del self.encoder.post_codebook_conv
 
     def encode(self, x: Tensor) -> Tensor:
-        """Encodes an image into a latent feature vector.
+        """Encodes an image into a discrete (VQ) latent representation.
 
         Args:
             x (Tensor): Image tensor.
@@ -69,8 +62,11 @@ class VqTransformer(ConvTransformer):
         Returns:
             Tensor: A latent embedding of the image.
         """
-        z = self.encoder(x)
-        z = self.latent_encoder(z)
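+        # The VQ-VAE encoder runs without gradient tracking; only the
+        # latent projector and the transformer decoder are trained.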
+        with torch.no_grad():
+            z_e = self.encoder.encode(x)
+            z_q, _ = self.encoder.quantize(z_e)
+
+        z = self.latent_encoder(z_q)
 
         # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
         z = z.permute(0, 2, 1)
diff --git a/text_recognizer/networks/vqvae/__init__.py b/text_recognizer/networks/vqvae/__init__.py
index 7d56bdb..e1f05fa 100644
--- a/text_recognizer/networks/vqvae/__init__.py
+++ b/text_recognizer/networks/vqvae/__init__.py
@@ -1,2 +1 @@
 """VQ-VAE module."""
-from .vqvae import VQVAE
diff --git a/text_recognizer/networks/vqvae/attention.py b/text_recognizer/networks/vqvae/attention.py
index 5a6b3ce..78a2cc9 100644
--- a/text_recognizer/networks/vqvae/attention.py
+++ b/text_recognizer/networks/vqvae/attention.py
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 from text_recognizer.networks.vqvae.norm import Normalize
 
 
-@attr.s
+@attr.s(eq=False)
 class Attention(nn.Module):
     """Convolutional attention."""
 
@@ -63,11 +63,12 @@ class Attention(nn.Module):
         B, C, H, W = q.shape
         q = q.reshape(B, C, H * W).permute(0, 2, 1)  # [B, HW, C]
         k = k.reshape(B, C, H * W)  # [B, C, HW]
-        energy = torch.bmm(q, k) * (C ** -0.5)
+        energy = torch.bmm(q, k) * (int(C) ** -0.5)
         attention = F.softmax(energy, dim=2)
 
         # Compute attention to which values
-        v = v.reshape(B, C, H * W).permute(0, 2, 1)  # [B, HW, C]
+        v = v.reshape(B, C, H * W)
+        attention = attention.permute(0, 2, 1)  # [B, HW, HW]
         out = torch.bmm(v, attention)
         out = out.reshape(B, C, H, W)
         out = self.proj(out)
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
index fcf768b..f51e0a3 100644
--- a/text_recognizer/networks/vqvae/decoder.py
+++ b/text_recognizer/networks/vqvae/decoder.py
@@ -1,62 +1,69 @@
 """CNN decoder for the VQ-VAE."""
-import attr
+from typing import Sequence
+
 from torch import nn
 from torch import Tensor
 
 from text_recognizer.networks.util import activation_function
+from text_recognizer.networks.vqvae.norm import Normalize
 from text_recognizer.networks.vqvae.residual import Residual
 
 
-@attr.s(eq=False)
 class Decoder(nn.Module):
     """A CNN encoder network."""
 
-    in_channels: int = attr.ib()
-    embedding_dim: int = attr.ib()
-    out_channels: int = attr.ib()
-    res_channels: int = attr.ib()
-    num_residual_layers: int = attr.ib()
-    activation: str = attr.ib()
-    decoder: nn.Sequential = attr.ib(init=False)
-
-    def __attrs_post_init__(self) -> None:
-        """Post init configuration."""
+    def __init__(
+        self,
+        out_channels: int,
+        hidden_dim: int,
+        channels_multipliers: Sequence[int],
+        dropout_rate: float,
+        activation: str = "mish",
+    ) -> None:
         super().__init__()
+        self.out_channels = out_channels
+        self.hidden_dim = hidden_dim
+        self.channels_multipliers = tuple(channels_multipliers)
+        self.activation = activation
+        self.dropout_rate = dropout_rate
         self.decoder = self._build_decompression_block()
 
     def _build_decompression_block(self,) -> nn.Sequential:
+        in_channels = self.hidden_dim * self.channels_multipliers[0]
+        decoder = []
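+        # Mirror the encoder: two residual blocks at the bottleneck
+        # before upsampling back to the input resolution.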
+        for _ in range(2):
+            decoder += [
+                Residual(
+                    in_channels=in_channels,
+                    out_channels=in_channels,
+                    dropout_rate=self.dropout_rate,
+                    use_norm=True,
+                ),
+            ]
+
         activation_fn = activation_function(self.activation)
-        blocks = [
+        out_channels_multipliers = self.channels_multipliers + (1, )
+        num_blocks = len(self.channels_multipliers)
+
+        for i in range(num_blocks):
+            in_channels = self.hidden_dim * self.channels_multipliers[i]
+            out_channels = self.hidden_dim * out_channels_multipliers[i + 1]
+            decoder += [
+                nn.ConvTranspose2d(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    kernel_size=4,
+                    stride=2,
+                    padding=1,
+                ),
+                activation_fn,
+            ]
+
+        decoder += [
+            Normalize(num_channels=self.hidden_dim * out_channels_multipliers[-1]),
+            nn.Mish(),
             nn.Conv2d(
-                in_channels=self.in_channels,
-                out_channels=self.embedding_dim,
-                kernel_size=3,
-                padding=1,
-            )
-        ]
-        for _ in range(self.num_residual_layers):
-            blocks.append(
-                Residual(in_channels=self.embedding_dim, out_channels=self.res_channels)
-            )
-        blocks.append(activation_fn)
-        blocks += [
-            nn.ConvTranspose2d(
-                in_channels=self.embedding_dim,
-                out_channels=self.embedding_dim // 2,
-                kernel_size=4,
-                stride=2,
-                padding=1,
-            ),
-            activation_fn,
-            nn.ConvTranspose2d(
-                in_channels=self.embedding_dim // 2,
+                in_channels=self.hidden_dim * out_channels_multipliers[-1],
                 out_channels=self.out_channels,
-                kernel_size=4,
-                stride=2,
+                kernel_size=3,
+                stride=1,
                 padding=1,
             ),
         ]
-        return nn.Sequential(*blocks)
+        return nn.Sequential(*decoder)
 
     def forward(self, z_q: Tensor) -> Tensor:
         """Reconstruct input from given codes."""
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
index f086c6b..ad8f950 100644
--- a/text_recognizer/networks/vqvae/encoder.py
+++ b/text_recognizer/networks/vqvae/encoder.py
@@ -1,7 +1,6 @@
 """CNN encoder for the VQ-VAE."""
-from typing import Sequence, Optional, Tuple, Type
+from typing import List, Tuple
 
-import attr
 from torch import nn
 from torch import Tensor
 
@@ -9,64 +8,59 @@ from text_recognizer.networks.util import activation_function
 from text_recognizer.networks.vqvae.residual import Residual
 
 
-@attr.s(eq=False)
 class Encoder(nn.Module):
     """A CNN encoder network."""
 
-    in_channels: int = attr.ib()
-    out_channels: int = attr.ib()
-    res_channels: int = attr.ib()
-    num_residual_layers: int = attr.ib()
-    embedding_dim: int = attr.ib()
-    activation: str = attr.ib()
-    encoder: nn.Sequential = attr.ib(init=False)
-
-    def __attrs_post_init__(self) -> None:
-        """Post init configuration."""
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_dim: int,
+        channels_multipliers: List[int],
+        dropout_rate: float,
+        activation: str = "mish",
+    ) -> None:
         super().__init__()
+        self.in_channels = in_channels
+        self.hidden_dim = hidden_dim
+        self.channels_multipliers = tuple(channels_multipliers)
+        self.activation = activation
+        self.dropout_rate = dropout_rate
         self.encoder = self._build_compression_block()
 
     def _build_compression_block(self) -> nn.Sequential:
-        activation_fn = activation_function(self.activation)
-        block = [
+        """Builds encoder network."""
+        encoder = [
             nn.Conv2d(
                 in_channels=self.in_channels,
-                out_channels=self.out_channels // 2,
-                kernel_size=4,
-                stride=2,
-                padding=1,
-            ),
-            activation_fn,
-            nn.Conv2d(
-                in_channels=self.out_channels // 2,
-                out_channels=self.out_channels,
-                kernel_size=4,
-                stride=2,
-                padding=1,
-            ),
-            activation_fn,
-            nn.Conv2d(
-                in_channels=self.out_channels,
-                out_channels=self.out_channels,
+                out_channels=self.hidden_dim,
                 kernel_size=3,
+                stride=1,
                 padding=1,
             ),
         ]
 
-        for _ in range(self.num_residual_layers):
-            block.append(
-                Residual(in_channels=self.out_channels, out_channels=self.res_channels)
-            )
+        num_blocks = len(self.channels_multipliers)
+        channels_multipliers = (1, ) + self.channels_multipliers
+        activation_fn = activation_function(self.activation)
 
-        block.append(
-            nn.Conv2d(
-                in_channels=self.out_channels,
-                out_channels=self.embedding_dim,
-                kernel_size=1,
-            )
-        )
+        for i in range(num_blocks):
+            in_channels = self.hidden_dim * channels_multipliers[i]
+            out_channels = self.hidden_dim * channels_multipliers[i + 1]
+            encoder += [
+                nn.Conv2d(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    kernel_size=4,
+                    stride=2,
+                    padding=1,
+                ),
+                activation_fn,
+            ]
+
+        for _ in range(2):
+            encoder += [
+                Residual(
+                    in_channels=self.hidden_dim * self.channels_multipliers[-1],
+                    out_channels=self.hidden_dim * self.channels_multipliers[-1],
+                    dropout_rate=self.dropout_rate,
+                    use_norm=True,
+                )
+            ]
 
-        return nn.Sequential(*block)
+        return nn.Sequential(*encoder)
 
-    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
-        """Encodes input into a discrete representation."""
+    def forward(self, x: Tensor) -> Tensor:
+        """Encodes input into a latent representation."""
diff --git a/text_recognizer/networks/vqvae/norm.py b/text_recognizer/networks/vqvae/norm.py
index df66efc..3e6963a 100644
--- a/text_recognizer/networks/vqvae/norm.py
+++ b/text_recognizer/networks/vqvae/norm.py
@@ -3,7 +3,7 @@ import attr
 from torch import nn, Tensor
 
 
-@attr.s
+@attr.s(eq=False)
 class Normalize(nn.Module):
     num_channels: int = attr.ib()
     norm: nn.GroupNorm = attr.ib(init=False)
@@ -12,7 +12,7 @@ class Normalize(nn.Module):
         """Post init configuration."""
         super().__init__()
         self.norm = nn.GroupNorm(
-            num_groups=32, num_channels=self.num_channels, eps=1.0e-6, affine=True
+            num_groups=self.num_channels, num_channels=self.num_channels, eps=1.0e-6, affine=True
         )
 
     def forward(self, x: Tensor) -> Tensor:
diff --git a/text_recognizer/networks/vqvae/pixelcnn.py b/text_recognizer/networks/vqvae/pixelcnn.py
new file mode 100644
index 0000000..5c580df
--- /dev/null
+++ b/text_recognizer/networks/vqvae/pixelcnn.py
@@ -0,0 +1,165 @@
+"""PixelCNN encoder and decoder.
+
+Same as in VQGAN, among others. Hopefully, this yields better reconstructions...
+
+TODO: Make the number of residual layers configurable.
+"""
+from typing import Sequence
+
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.vqvae.attention import Attention
+from text_recognizer.networks.vqvae.norm import Normalize
+from text_recognizer.networks.vqvae.residual import Residual
+from text_recognizer.networks.vqvae.resize import Downsample, Upsample
+
+
+class Encoder(nn.Module):
+    """PixelCNN encoder."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_dim: int,
+        channels_multipliers: Sequence[int],
+        dropout_rate: float,
+    ) -> None:
+        super().__init__()
+        self.in_channels = in_channels
+        self.dropout_rate = dropout_rate
+        self.hidden_dim = hidden_dim
+        self.channels_multipliers = tuple(channels_multipliers)
+        self.encoder = self._build_encoder()
+
+    def _build_encoder(self) -> nn.Sequential:
+        """Builds encoder."""
+        encoder = [
+            nn.Conv2d(
+                in_channels=self.in_channels,
+                out_channels=self.hidden_dim,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+            ),
+        ]
+        num_blocks = len(self.channels_multipliers)
+        in_channels_multipliers = (1,) + self.channels_multipliers
+        for i in range(num_blocks):
+            in_channels = self.hidden_dim * in_channels_multipliers[i]
+            out_channels = self.hidden_dim * self.channels_multipliers[i]
+            encoder += [
+                Residual(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    dropout_rate=self.dropout_rate,
+                    use_norm=True,
+                ),
+            ]
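+            # Attention cost is quadratic in the spatial size, so it is
+            # only inserted at the deepest block before downsampling.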
+            if i == num_blocks - 1:
+                encoder.append(Attention(in_channels=out_channels))
+            encoder.append(Downsample())
+
+        for _ in range(2):
+            encoder += [
+                Residual(
+                    in_channels=self.hidden_dim * self.channels_multipliers[-1],
+                    out_channels=self.hidden_dim * self.channels_multipliers[-1],
+                    dropout_rate=self.dropout_rate,
+                    use_norm=True,
+                ),
+                Attention(in_channels=self.hidden_dim * self.channels_multipliers[-1])
+            ]
+
+        encoder += [
+            Normalize(num_channels=self.hidden_dim * self.channels_multipliers[-1]),
+            nn.Mish(),
+            nn.Conv2d(
+                in_channels=self.hidden_dim * self.channels_multipliers[-1],
+                out_channels=self.hidden_dim * self.channels_multipliers[-1],
+                kernel_size=3,
+                stride=1,
+                padding=1,
+            ),
+        ]
+        return nn.Sequential(*encoder)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Encodes input to a latent representation."""
+        return self.encoder(x)
+
+
+class Decoder(nn.Module):
+    """PixelCNN decoder."""
+
+    def __init__(
+        self,
+        hidden_dim: int,
+        channels_multipliers: Sequence[int],
+        out_channels: int,
+        dropout_rate: float,
+    ) -> None:
+        super().__init__()
+        self.hidden_dim = hidden_dim
+        self.out_channels = out_channels
+        self.channels_multipliers = tuple(channels_multipliers)
+        self.dropout_rate = dropout_rate
+        self.decoder = self._build_decoder()
+
+    def _build_decoder(self) -> nn.Sequential:
+        """Builds decoder."""
+        in_channels = self.hidden_dim * self.channels_multipliers[0]
+        decoder = [
+            Residual(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                dropout_rate=self.dropout_rate,
+                use_norm=True,
+            ),
+            Attention(in_channels=in_channels),
+            Residual(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                dropout_rate=self.dropout_rate,
+                use_norm=True,
+            ),
+        ]
+
+        out_channels_multipliers = self.channels_multipliers + (1, )
+        num_blocks = len(self.channels_multipliers)
+
+        for i in range(num_blocks):
+            in_channels = self.hidden_dim * self.channels_multipliers[i]
+            out_channels = self.hidden_dim * out_channels_multipliers[i + 1]
+            decoder.append(
+                Residual(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    dropout_rate=self.dropout_rate,
+                    use_norm=True,
+                )
+            )
+            if i == 0:
+                decoder.append(Attention(in_channels=out_channels))
+            decoder.append(Upsample())
+
+        decoder += [
+            Normalize(num_channels=self.hidden_dim * out_channels_multipliers[-1]),
+            nn.Mish(),
+            nn.Conv2d(
+                in_channels=self.hidden_dim * out_channels_multipliers[-1],
+                out_channels=self.out_channels,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+            ),
+        ]
+        return nn.Sequential(*decoder)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Decodes latent vector."""
+        return self.decoder(x)
diff --git a/text_recognizer/networks/vqvae/quantizer.py b/text_recognizer/networks/vqvae/quantizer.py
index a4f11f0..6fb57e8 100644
--- a/text_recognizer/networks/vqvae/quantizer.py
+++ b/text_recognizer/networks/vqvae/quantizer.py
@@ -11,13 +11,15 @@ import torch.nn.functional as F
 
 
 class EmbeddingEMA(nn.Module):
+    """Embedding for Exponential Moving Average (EMA)."""
+
     def __init__(self, num_embeddings: int, embedding_dim: int) -> None:
         super().__init__()
         weight = torch.zeros(num_embeddings, embedding_dim)
         nn.init.kaiming_uniform_(weight, nonlinearity="linear")
         self.register_buffer("weight", weight)
-        self.register_buffer("_cluster_size", torch.zeros(num_embeddings))
-        self.register_buffer("_weight_avg", weight)
+        self.register_buffer("cluster_size", torch.zeros(num_embeddings))
+        self.register_buffer("weight_avg", weight.clone())
 
 
 class VectorQuantizer(nn.Module):
@@ -81,16 +83,17 @@ class VectorQuantizer(nn.Module):
         return quantized_latent
 
     def compute_ema(self, one_hot_encoding: Tensor, latent: Tensor) -> None:
+        """Computes the EMA update to the codebook."""
         batch_cluster_size = one_hot_encoding.sum(axis=0)
         batch_embedding_avg = (latent.t() @ one_hot_encoding).t()
-        self.embedding._cluster_size.data.mul_(self.decay).add_(
+        self.embedding.cluster_size.data.mul_(self.decay).add_(
             batch_cluster_size, alpha=1 - self.decay
         )
-        self.embedding._weight_avg.data.mul_(self.decay).add_(
+        self.embedding.weight_avg.data.mul_(self.decay).add_(
             batch_embedding_avg, alpha=1 - self.decay
         )
-        new_embedding = self.embedding._weight_avg / (
-            self.embedding._cluster_size + 1.0e-5
+        new_embedding = self.embedding.weight_avg / (
+            self.embedding.cluster_size + 1.0e-5
         ).unsqueeze(1)
         self.embedding.weight.data.copy_(new_embedding)
 
diff --git a/text_recognizer/networks/vqvae/residual.py b/text_recognizer/networks/vqvae/residual.py
index 98109b8..4ed3781 100644
--- a/text_recognizer/networks/vqvae/residual.py
+++ b/text_recognizer/networks/vqvae/residual.py
@@ -1,18 +1,55 @@
 """Residual block."""
+import attr
 from torch import nn
 from torch import Tensor
 
+from text_recognizer.networks.vqvae.norm import Normalize
 
+
+@attr.s(eq=False)
 class Residual(nn.Module):
-    def __init__(self, in_channels: int, out_channels: int,) -> None:
+    in_channels: int = attr.ib()
+    out_channels: int = attr.ib()
+    dropout_rate: float = attr.ib(default=0.0)
+    use_norm: bool = attr.ib(default=False)
+
+    def __attrs_post_init__(self) -> None:
+        """Post init configuration."""
         super().__init__()
-        self.block = nn.Sequential(
-            nn.Mish(inplace=True),
-            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
-            nn.Mish(inplace=True),
-            nn.Conv2d(out_channels, in_channels, kernel_size=1, bias=False),
-        )
+        self.block = self._build_res_block()
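+        # Project the skip connection when the channel count changes.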
+        if self.in_channels != self.out_channels:
+            self.conv_shortcut = nn.Conv2d(
+                in_channels=self.in_channels,
+                out_channels=self.out_channels,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+            )
+        else:
+            self.conv_shortcut = None
+
+    def _build_res_block(self) -> nn.Sequential:
+        """Build residual block."""
+        block = []
+        if self.use_norm:
+            block.append(Normalize(num_channels=self.in_channels))
+        block += [
+            nn.Mish(),
+            nn.Conv2d(
+                self.in_channels,
+                self.out_channels,
+                kernel_size=3,
+                padding=1,
+                bias=False,
+            ),
+        ]
+        if self.dropout_rate:
+            block += [nn.Dropout(p=self.dropout_rate)]
+
+        if self.use_norm:
+            block.append(Normalize(num_channels=self.out_channels))
+
+        block += [
+            nn.Mish(),
+            nn.Conv2d(self.out_channels, self.out_channels, kernel_size=1, bias=False),
+        ]
+        return nn.Sequential(*block)
 
     def forward(self, x: Tensor) -> Tensor:
         """Apply the residual forward pass."""
-        return x + self.block(x)
+        residual = self.conv_shortcut(x) if self.conv_shortcut is not None else x
+        return residual + self.block(x)
diff --git a/text_recognizer/networks/vqvae/resize.py b/text_recognizer/networks/vqvae/resize.py
index 769d089..8d67d02 100644
--- a/text_recognizer/networks/vqvae/resize.py
+++ b/text_recognizer/networks/vqvae/resize.py
@@ -8,7 +8,7 @@ class Upsample(nn.Module):
 
     def forward(self, x: Tensor) -> Tensor:
         """Applies upsampling."""
-        return F.interpolate(x, scale_factor=2, mode="nearest")
+        return F.interpolate(x, scale_factor=2.0, mode="nearest")
 
 
 class Downsample(nn.Module):
diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py
index 1585d40..0646119 100644
--- a/text_recognizer/networks/vqvae/vqvae.py
+++ b/text_recognizer/networks/vqvae/vqvae.py
@@ -1,13 +1,9 @@
 """The VQ-VAE."""
 from typing import Tuple
 
-import torch
 from torch import nn
 from torch import Tensor
-import torch.nn.functional as F
 
-from text_recognizer.networks.vqvae.decoder import Decoder
-from text_recognizer.networks.vqvae.encoder import Encoder
 from text_recognizer.networks.vqvae.quantizer import VectorQuantizer
 
 
@@ -16,93 +12,45 @@ class VQVAE(nn.Module):
 
     def __init__(
         self,
-        in_channels: int,
-        res_channels: int,
-        num_residual_layers: int,
+        encoder: nn.Module,
+        decoder: nn.Module,
+        hidden_dim: int,
         embedding_dim: int,
         num_embeddings: int,
         decay: float = 0.99,
-        activation: str = "mish",
     ) -> None:
         super().__init__()
-        # Encoders
-        self.btm_encoder = Encoder(
-            in_channels=1,
-            out_channels=embedding_dim,
-            res_channels=res_channels,
-            num_residual_layers=num_residual_layers,
-            embedding_dim=embedding_dim,
-            activation=activation,
+        self.encoder = encoder
+        self.decoder = decoder
+        self.pre_codebook_conv = nn.Conv2d(
+            in_channels=hidden_dim, out_channels=embedding_dim, kernel_size=1
         )
-
-        self.top_encoder = Encoder(
-            in_channels=embedding_dim,
-            out_channels=embedding_dim,
-            res_channels=res_channels,
-            num_residual_layers=num_residual_layers,
-            embedding_dim=embedding_dim,
-            activation=activation,
+        self.post_codebook_conv = nn.Conv2d(
+            in_channels=embedding_dim, out_channels=hidden_dim, kernel_size=1
         )
-
-        # Quantizers
-        self.btm_quantizer = VectorQuantizer(
-            num_embeddings=num_embeddings, embedding_dim=embedding_dim, decay=decay,
-        )
-
-        self.top_quantizer = VectorQuantizer(
+        self.quantizer = VectorQuantizer(
             num_embeddings=num_embeddings, embedding_dim=embedding_dim, decay=decay,
         )
 
-        # Decoders
-        self.top_decoder = Decoder(
-            in_channels=embedding_dim,
-            out_channels=embedding_dim,
-            embedding_dim=embedding_dim,
-            res_channels=res_channels,
-            num_residual_layers=num_residual_layers,
-            activation=activation,
-        )
-
-        self.btm_decoder = Decoder(
-            in_channels=2 * embedding_dim,
-            out_channels=in_channels,
-            embedding_dim=embedding_dim,
-            res_channels=res_channels,
-            num_residual_layers=num_residual_layers,
-            activation=activation,
-        )
 
-    def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+    def encode(self, x: Tensor) -> Tensor:
         """Encodes input to a latent code."""
-        z_btm = self.btm_encoder(x)
-        z_top = self.top_encoder(z_btm)
-        return z_btm, z_top
+        z_e = self.encoder(x)
+        return self.pre_codebook_conv(z_e)
 
-    def quantize(
-        self, z_btm: Tensor, z_top: Tensor
-    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
-        q_btm, vq_btm_loss = self.top_quantizer(z_btm)
-        q_top, vq_top_loss = self.top_quantizer(z_top)
-        return q_btm, vq_btm_loss, q_top, vq_top_loss
+    def quantize(self, z_e: Tensor) -> Tuple[Tensor, Tensor]:
+        """Quantizes the encoded latent and returns it with the VQ loss."""
+        z_q, vq_loss = self.quantizer(z_e)
+        return z_q, vq_loss
 
-    def decode(self, q_btm: Tensor, q_top: Tensor) -> Tuple[Tensor, Tensor]:
+    def decode(self, z_q: Tensor) -> Tensor:
         """Reconstructs input from latent codes."""
-        d_top = self.top_decoder(q_top)
-        x_hat = self.btm_decoder(torch.cat((d_top, q_btm), dim=1))
-        return d_top, x_hat
-
-    def loss_fn(
-        self, vq_btm_loss: Tensor, vq_top_loss: Tensor, d_top: Tensor, z_btm: Tensor
-    ) -> Tensor:
-        """Calculates the latent loss."""
-        return 0.5 * (vq_top_loss + vq_btm_loss) + F.mse_loss(d_top, z_btm)
+        z = self.post_codebook_conv(z_q)
+        x_hat = self.decoder(z)
+        return x_hat
 
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
         """Compresses and decompresses input."""
-        z_btm, z_top = self.encode(x)
-        q_btm, vq_btm_loss, q_top, vq_top_loss = self.quantize(z_btm=z_btm, z_top=z_top)
-        d_top, x_hat = self.decode(q_btm=q_btm, q_top=q_top)
-        vq_loss = self.loss_fn(
-            vq_btm_loss=vq_btm_loss, vq_top_loss=vq_top_loss, d_top=d_top, z_btm=z_btm
-        )
+        z_e = self.encode(x)
+        z_q, vq_loss = self.quantize(z_e)
+        x_hat = self.decode(z_q)
         return x_hat, vq_loss
-- 