From d691b548cd0b6fc4ea184d64261f633789fee021 Mon Sep 17 00:00:00 2001
From: aktersnurra <gustaf.rydholm@gmail.com>
Date: Thu, 7 Jan 2021 23:35:42 +0100
Subject: working on vq-vae

---
 src/text_recognizer/networks/cnn_transformer.py    |  12 ---
 src/text_recognizer/networks/vqvae/__init__.py     |   1 +
 src/text_recognizer/networks/vqvae/encoder.py      |  75 ++++++++++++
 .../networks/vqvae/vector_quantizer.py             | 131 ++++++++++++++++++++++
 4 files changed, 207 insertions(+), 12 deletions(-)
 create mode 100644 src/text_recognizer/networks/vqvae/__init__.py
 create mode 100644 src/text_recognizer/networks/vqvae/encoder.py
 create mode 100644 src/text_recognizer/networks/vqvae/vector_quantizer.py

diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
index caa73e3..43e5403 100644
--- a/src/text_recognizer/networks/cnn_transformer.py
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -109,18 +109,6 @@ class CNNTransformer(nn.Module):
 
         b, t, _ = src.shape
 
-        # Insert sos and eos token.
-        # sos_token = self.character_embedding(
-        #    torch.Tensor([self.vocab_size - 2]).long().to(src.device)
-        # )
-        # eos_token = self.character_embedding(
-        #    torch.Tensor([self.vocab_size - 1]).long().to(src.device)
-        # )
-
-        # sos_tokens = repeat(sos_token, "() h -> b h", b=b).unsqueeze(1)
-        # eos_tokens = repeat(eos_token, "() h -> b h", b=b).unsqueeze(1)
-        # src = torch.cat((sos_tokens, src, eos_tokens), dim=1)
-        # src = torch.cat((sos_tokens, src), dim=1)
         src += self.src_position_embedding[:, :t]
 
         return src
diff --git a/src/text_recognizer/networks/vqvae/__init__.py b/src/text_recognizer/networks/vqvae/__init__.py
new file mode 100644
index 0000000..e1f05fa
--- /dev/null
+++ b/src/text_recognizer/networks/vqvae/__init__.py
@@ -0,0 +1 @@
+"""VQ-VAE module."""
diff --git a/src/text_recognizer/networks/vqvae/encoder.py b/src/text_recognizer/networks/vqvae/encoder.py
new file mode 100644
index 0000000..60c4c43
--- /dev/null
+++ b/src/text_recognizer/networks/vqvae/encoder.py
@@ -0,0 +1,75 @@
+"""CNN encoder for the VQ-VAE."""
+
+from typing import List, Optional
+
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import activation_function
+from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer
+
+
+class _ResidualBlock(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        activation: nn.Module,
+        dropout: Optional[nn.Module],
+    ) -> None:
+        super().__init__()
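+        # NOTE: the residual addition in forward assumes in_channels == out_channels.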
+        self.block = [
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
+            activation,
+            nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False),
+        ]
+
+        if dropout is not None:
+            self.block.append(dropout)
+
+        self.block = nn.Sequential(*self.block)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Apply the residual forward pass."""
+        return x + self.block(x)
+
+
+class Encoder(nn.Module):
+    """A CNN encoder network."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        channels: List[int],
+        num_residual_layers: int,
+        embedding_dim: int,
+        num_embeddings: int,
+        beta: float = 0.25,
+        activation: str = "elu",
+        dropout_rate: float = 0.0,
+    ) -> None:
+        super().__init__()
+        self.in_channels = in_channels
+        self.channels = channels
+        self.num_residual_layers = num_residual_layers
+        self.activation = activation_function(activation)
+
+        # The ctor args mirror VectorQuantizer's signature, so wire it up here.
+        self.vector_quantizer = VectorQuantizer(num_embeddings, embedding_dim, beta)
+
+        # SELU pairs with alpha dropout to preserve its self-normalization.
+        if dropout_rate:
+            if activation == "selu":
+                self.dropout = nn.AlphaDropout(p=dropout_rate)
+            else:
+                self.dropout = nn.Dropout(p=dropout_rate)
+        else:
+            self.dropout = None
+
+    def _build_encoder(self) -> nn.Sequential:
+        # TODO: Continue to implement encoder.
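+        # One possible layout (an assumption, not the final design): strided
+        # convs over `channels`, then `num_residual_layers` _ResidualBlocks,
+        # then a 1x1 conv to `embedding_dim` feeding the VectorQuantizer.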
+        raise NotImplementedError
diff --git a/src/text_recognizer/networks/vqvae/vector_quantizer.py b/src/text_recognizer/networks/vqvae/vector_quantizer.py
new file mode 100644
index 0000000..25e5583
--- /dev/null
+++ b/src/text_recognizer/networks/vqvae/vector_quantizer.py
@@ -0,0 +1,131 @@
+"""Implementation of a Vector Quantized Variational AutoEncoder.
+
+Reference:
+https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py
+
+"""
+
+from typing import Tuple
+
+from einops import rearrange
+import torch
+from torch import nn
+from torch import Tensor
+from torch.nn import functional as F
+
+
+class VectorQuantizer(nn.Module):
+    """The codebook that contains quantized vectors."""
+
+    def __init__(
+        self, num_embeddings: int, embedding_dim: int, beta: float = 0.25
+    ) -> None:
+        super().__init__()
+        self.K = num_embeddings
+        self.D = embedding_dim
+        self.beta = beta
+
+        self.embedding = nn.Embedding(self.K, self.D)
+
+        # Initialize the codebook.
+        self.embedding.weight.data.uniform_(-1 / self.K, 1 / self.K)
+
+    def discretization_bottleneck(self, latent: Tensor) -> Tensor:
+        """Computes the code nearest to the latent representation.
+
+        First we compute the posterior categorical distribution, and then map
+        the latent representation to the nearest element of the embedding.
+
+        Args:
+            latent (Tensor): The latent representation.
+
+        Shape:
+            - latent :math:`(B, H, W, D)`
+
+        Returns:
+            Tensor: The quantized embedding vector.
+
+        """
+        # Store latent shape.
+        b, h, w, d = latent.shape
+
+        # Flatten the latent representation to 2D.
+        latent = rearrange(latent, "b h w d -> (b h w) d")
+
+        # Compute the L2 distance between the latents and the embeddings.
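+        # Uses ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z.e, broadcast over all pairs.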
+        l2_distance = (
+            torch.sum(latent ** 2, dim=1, keepdim=True)
+            + torch.sum(self.embedding.weight ** 2, dim=1)
+            - 2 * latent @ self.embedding.weight.t()
+        )  # [BHW, K]
+
+        # Find the embedding k nearest to each latent.
+        encoding_indices = torch.argmin(l2_distance, dim=1).unsqueeze(1)  # [BHW, 1]
+
+        # Convert to one-hot encodings, aka discrete bottleneck.
+        one_hot_encoding = torch.zeros(
+            encoding_indices.shape[0], self.K, device=latent.device
+        )
+        one_hot_encoding.scatter_(1, encoding_indices, 1)  # [BHW, K]
+
+        # Embedding quantization.
+        quantized_latent = one_hot_encoding @ self.embedding.weight  # [BHW, D]
+        quantized_latent = rearrange(
+            quantized_latent, "(b h w) d -> b h w d", b=b, h=h, w=w
+        )
+
+        return quantized_latent
+
+    def vq_loss(self, latent: Tensor, quantized_latent: Tensor) -> Tensor:
+        """Vector Quantization loss.
+
+        The vector quantization algorithm allows us to create a codebook. The VQ
+        algorithm works by moving the embedding vectors towards the encoder outputs.
+
+        The embedding loss moves the embedding vector towards the encoder outputs. The
+        .detach() works as the stop gradient (sg) described in the paper.
+
+        Because the volume of the embedding space is dimensionless, it can arbitrarily
+        grow if the embeddings are not trained as fast as the encoder parameters. To
+        mitigate this, a commitment loss is added in the second term which makes sure
+        that the encoder commits to an embedding and that its output does not grow.
+
+        Args:
+            latent (Tensor): The encoder output.
+            quantized_latent (Tensor): The quantized latent.
+
+        Returns:
+            Tensor: The combined VQ loss.
+
+        """
+        embedding_loss = F.mse_loss(quantized_latent, latent.detach())
+        commitment_loss = F.mse_loss(quantized_latent.detach(), latent)
+        return embedding_loss + self.beta * commitment_loss
+
+    def forward(self, latent: Tensor) -> Tuple[Tensor, Tensor]:
+        """Forward pass that returns the quantized vector and the vq loss."""
+        # Rearrange latent representation s.t. the hidden dim is at the end.
+        latent = rearrange(latent, "b d h w -> b h w d")
+
+        # Maps latent to the nearest code in the codebook.
+        quantized_latent = self.discretization_bottleneck(latent)
+
+        loss = self.vq_loss(latent, quantized_latent)
+
+        # Straight-through estimator: forward passes the quantized values,
+        # backward routes the gradients straight to the encoder output.
+        quantized_latent = latent + (quantized_latent - latent).detach()
+
+        # Rearrange the quantized shape back to the original shape.
+        quantized_latent = rearrange(quantized_latent, "b h w d -> b d h w")
+
+        return quantized_latent, loss
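+
+
+# Usage sketch (illustrative only; the dimensions are arbitrary):
+#
+#   vq = VectorQuantizer(num_embeddings=512, embedding_dim=16)
+#   latent = torch.randn(2, 16, 8, 8)  # (B, D, H, W) from an encoder
+#   quantized, loss = vq(latent)  # quantized keeps the latent's shape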
-- 
cgit v1.2.3-70-g09d2