summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/vqvae
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2021-01-07 23:35:42 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2021-01-07 23:35:42 +0100
commitd691b548cd0b6fc4ea184d64261f633789fee021 (patch)
tree99e2fc5481ce102d5655b65681274e5f0286306f /src/text_recognizer/networks/vqvae
parentff9a21d333f11a42e67c1963ed67de9c0fda87c9 (diff)
working on vq-vae
Diffstat (limited to 'src/text_recognizer/networks/vqvae')
-rw-r--r--src/text_recognizer/networks/vqvae/__init__.py1
-rw-r--r--src/text_recognizer/networks/vqvae/encoder.py64
-rw-r--r--src/text_recognizer/networks/vqvae/vector_quantizer.py119
3 files changed, 184 insertions, 0 deletions
diff --git a/src/text_recognizer/networks/vqvae/__init__.py b/src/text_recognizer/networks/vqvae/__init__.py
new file mode 100644
index 0000000..e1f05fa
--- /dev/null
+++ b/src/text_recognizer/networks/vqvae/__init__.py
@@ -0,0 +1 @@
+"""VQ-VAE module."""
diff --git a/src/text_recognizer/networks/vqvae/encoder.py b/src/text_recognizer/networks/vqvae/encoder.py
new file mode 100644
index 0000000..60c4c43
--- /dev/null
+++ b/src/text_recognizer/networks/vqvae/encoder.py
@@ -0,0 +1,64 @@
+"""CNN encoder for the VQ-VAE."""
+
+from typing import List, Optional, Type
+
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import activation_function
+from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer
+
+
+class _ResidualBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ activation: Type[nn.Module],
+ dropout: Optional[Type[nn.Module]],
+ ) -> None:
+ super().__init__()
+ self.block = [
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
+ activation,
+ nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False),
+ ]
+
+ if dropout is not None:
+ self.block.append(dropout)
+
+ self.block = nn.Sequential(*self.block)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Apply the residual forward pass."""
+ return x + self.block(x)
+
+
+class Encoder(nn.Module):
+ """A CNN encoder network."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ channels: List[int],
+ num_residual_layers: int,
+ embedding_dim: int,
+ num_embeddings: int,
+ beta: float = 0.25,
+ activation: str = "elu",
+ dropout_rate: float = 0.0,
+ ) -> None:
+ super().__init__()
+ pass
+ # if dropout_rate:
+ # if activation == "selu":
+ # dropout = nn.AlphaDropout(p=dropout_rate)
+ # else:
+ # dropout = nn.Dropout(p=dropout_rate)
+ # else:
+ # dropout = None
+
+ def _build_encoder(self) -> nn.Sequential:
+ # TODO: Continue to implement encoder.
+ pass
diff --git a/src/text_recognizer/networks/vqvae/vector_quantizer.py b/src/text_recognizer/networks/vqvae/vector_quantizer.py
new file mode 100644
index 0000000..25e5583
--- /dev/null
+++ b/src/text_recognizer/networks/vqvae/vector_quantizer.py
@@ -0,0 +1,119 @@
+"""Implementation of a Vector Quantized Variational AutoEncoder.
+
+Reference:
+https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py
+
+"""
+
+from einops import rearrange
+import torch
+from torch import nn
+from torch import Tensor
+from torch.nn import functional as F
+
+
+class VectorQuantizer(nn.Module):
+ """The codebook that contains quantized vectors."""
+
+ def __init__(
+ self, num_embeddings: int, embedding_dim: int, beta: float = 0.25
+ ) -> None:
+ super().__init__()
+ self.K = num_embeddings
+ self.D = embedding_dim
+ self.beta = beta
+
+ self.embedding = nn.Embedding(self.K, self.D)
+
+ # Initialize the codebook.
+ self.embedding.weight.uniform_(-1 / self.K, 1 / self.K)
+
+ def discretization_bottleneck(self, latent: Tensor) -> Tensor:
+ """Computes the code nearest to the latent representation.
+
+ First we compute the posterior categorical distribution, and then map
+ the latent representation to the nearest element of the embedding.
+
+ Args:
+ latent (Tensor): The latent representation.
+
+ Shape:
+ - latent :math:`(B x H x W, D)`
+
+ Returns:
+ Tensor: The quantized embedding vector.
+
+ """
+ # Store latent shape.
+ b, h, w, d = latent.shape
+
+ # Flatten the latent representation to 2D.
+ latent = rearrange(latent, "b h w d -> (b h w) d")
+
+ # Compute the L2 distance between the latents and the embeddings.
+ l2_distance = (
+ torch.sum(latent ** 2, dim=1, keepdim=True)
+ + torch.sum(self.embedding.weight ** 2, dim=1)
+ - 2 * latent @ self.embedding.weight.t()
+ ) # [BHW x K]
+
+ # Find the embedding k nearest to each latent.
+ encoding_indices = torch.argmin(l2_distance, dim=1).unsqueeze(1) # [BHW, 1]
+
+ # Convert to one-hot encodings, aka discrete bottleneck.
+ one_hot_encoding = torch.zeros(
+ encoding_indices.shape[0], self.K, device=latent.device
+ )
+ one_hot_encoding.scatter_(1, encoding_indices, 1) # [BHW x K]
+
+ # Embedding quantization.
+ quantized_latent = one_hot_encoding @ self.embedding.weight # [BHW, D]
+ quantized_latent = rearrange(
+ quantized_latent, "(b h w) d -> b h w d", b=b, h=h, w=w
+ )
+
+ return quantized_latent
+
+ def vq_loss(self, latent: Tensor, quantized_latent: Tensor) -> Tensor:
+ """Vector Quantization loss.
+
+ The vector quantization algorithm allows us to create a codebook. The VQ
+ algorithm works by moving the embedding vectors towards the encoder outputs.
+
+ The embedding loss moves the embedding vector towards the encoder outputs. The
+ .detach() works as the stop gradient (sg) described in the paper.
+
+ Because the volume of the embedding space is dimensionless, it can arbitarily
+ grow if the embeddings are not trained as fast as the encoder parameters. To
+ mitigate this, a commitment loss is added in the second term which makes sure
+ that the encoder commits to an embedding and that its output does not grow.
+
+ Args:
+ latent (Tensor): The encoder output.
+ quantized_latent (Tensor): The quantized latent.
+
+ Returns:
+ Tensor: The combinded VQ loss.
+
+ """
+ embedding_loss = F.mse_loss(quantized_latent, latent.detach())
+ commitment_loss = F.mse_loss(quantized_latent.detach(), latent)
+ return embedding_loss + self.beta * commitment_loss
+
+ def forward(self, latent: Tensor) -> Tensor:
+ """Forward pass that returns the quantized vector and the vq loss."""
+ # Rearrange latent representation s.t. the hidden dim is at the end.
+ latent = rearrange(latent, "b d h w -> b h w d")
+
+ # Maps latent to the nearest code in the codebook.
+ quantized_latent = self.discretization_bottleneck(latent)
+
+ loss = self.vq_loss(latent, quantized_latent)
+
+ # Add residue to the quantized latent.
+ quantized_latent = latent + (quantized_latent - latent).detach()
+
+ # Rearrange the quantized shape back to the original shape.
+ quantized_latent = rearrange(quantized_latent, "b h w d -> b d h w")
+
+ return quantized_latent, loss