From d691b548cd0b6fc4ea184d64261f633789fee021 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Thu, 7 Jan 2021 23:35:42 +0100 Subject: working on vq-vae --- src/text_recognizer/networks/cnn_transformer.py | 12 --- src/text_recognizer/networks/vqvae/__init__.py | 1 + src/text_recognizer/networks/vqvae/encoder.py | 64 +++++++++++ .../networks/vqvae/vector_quantizer.py | 119 +++++++++++++++++++++ 4 files changed, 184 insertions(+), 12 deletions(-) create mode 100644 src/text_recognizer/networks/vqvae/__init__.py create mode 100644 src/text_recognizer/networks/vqvae/encoder.py create mode 100644 src/text_recognizer/networks/vqvae/vector_quantizer.py (limited to 'src/text_recognizer') diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py index caa73e3..43e5403 100644 --- a/src/text_recognizer/networks/cnn_transformer.py +++ b/src/text_recognizer/networks/cnn_transformer.py @@ -109,18 +109,6 @@ class CNNTransformer(nn.Module): b, t, _ = src.shape - # Insert sos and eos token. - # sos_token = self.character_embedding( - # torch.Tensor([self.vocab_size - 2]).long().to(src.device) - # ) - # eos_token = self.character_embedding( - # torch.Tensor([self.vocab_size - 1]).long().to(src.device) - # ) - - # sos_tokens = repeat(sos_token, "() h -> b h", b=b).unsqueeze(1) - # eos_tokens = repeat(eos_token, "() h -> b h", b=b).unsqueeze(1) - # src = torch.cat((sos_tokens, src, eos_tokens), dim=1) - # src = torch.cat((sos_tokens, src), dim=1) src += self.src_position_embedding[:, :t] return src diff --git a/src/text_recognizer/networks/vqvae/__init__.py b/src/text_recognizer/networks/vqvae/__init__.py new file mode 100644 index 0000000..e1f05fa --- /dev/null +++ b/src/text_recognizer/networks/vqvae/__init__.py @@ -0,0 +1 @@ +"""VQ-VAE module.""" diff --git a/src/text_recognizer/networks/vqvae/encoder.py b/src/text_recognizer/networks/vqvae/encoder.py new file mode 100644 index 0000000..60c4c43 --- /dev/null +++ b/src/text_recognizer/networks/vqvae/encoder.py @@ -0,0 +1,64 @@ +"""CNN encoder for the VQ-VAE.""" + +from typing import List, Optional, Type + +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.util import activation_function +from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer + + +class _ResidualBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + activation: Type[nn.Module], + dropout: Optional[Type[nn.Module]], + ) -> None: + super().__init__() + self.block = [ + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), + activation, + nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False), + ] + + if dropout is not None: + self.block.append(dropout) + + self.block = nn.Sequential(*self.block) + + def forward(self, x: Tensor) -> Tensor: + """Apply the residual forward pass.""" + return x + self.block(x) + + +class Encoder(nn.Module): + """A CNN encoder network.""" + + def __init__( + self, + in_channels: int, + channels: List[int], + num_residual_layers: int, + embedding_dim: int, + num_embeddings: int, + beta: float = 0.25, + activation: str = "elu", + dropout_rate: float = 0.0, + ) -> None: + super().__init__() + pass + # if dropout_rate: + # if activation == "selu": + # dropout = nn.AlphaDropout(p=dropout_rate) + # else: + # dropout = nn.Dropout(p=dropout_rate) + # else: + # dropout = None + + def _build_encoder(self) -> nn.Sequential: + # TODO: Continue to implement encoder. + pass diff --git a/src/text_recognizer/networks/vqvae/vector_quantizer.py b/src/text_recognizer/networks/vqvae/vector_quantizer.py new file mode 100644 index 0000000..25e5583 --- /dev/null +++ b/src/text_recognizer/networks/vqvae/vector_quantizer.py @@ -0,0 +1,119 @@ +"""Implementation of a Vector Quantized Variational AutoEncoder. + +Reference: +https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py + +""" + +from einops import rearrange +import torch +from torch import nn +from torch import Tensor +from torch.nn import functional as F + + +class VectorQuantizer(nn.Module): + """The codebook that contains quantized vectors.""" + + def __init__( + self, num_embeddings: int, embedding_dim: int, beta: float = 0.25 + ) -> None: + super().__init__() + self.K = num_embeddings + self.D = embedding_dim + self.beta = beta + + self.embedding = nn.Embedding(self.K, self.D) + + # Initialize the codebook. + self.embedding.weight.uniform_(-1 / self.K, 1 / self.K) + + def discretization_bottleneck(self, latent: Tensor) -> Tensor: + """Computes the code nearest to the latent representation. + + First we compute the posterior categorical distribution, and then map + the latent representation to the nearest element of the embedding. + + Args: + latent (Tensor): The latent representation. + + Shape: + - latent :math:`(B x H x W, D)` + + Returns: + Tensor: The quantized embedding vector. + + """ + # Store latent shape. + b, h, w, d = latent.shape + + # Flatten the latent representation to 2D. + latent = rearrange(latent, "b h w d -> (b h w) d") + + # Compute the L2 distance between the latents and the embeddings. + l2_distance = ( + torch.sum(latent ** 2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight ** 2, dim=1) + - 2 * latent @ self.embedding.weight.t() + ) # [BHW x K] + + # Find the embedding k nearest to each latent. + encoding_indices = torch.argmin(l2_distance, dim=1).unsqueeze(1) # [BHW, 1] + + # Convert to one-hot encodings, aka discrete bottleneck. + one_hot_encoding = torch.zeros( + encoding_indices.shape[0], self.K, device=latent.device + ) + one_hot_encoding.scatter_(1, encoding_indices, 1) # [BHW x K] + + # Embedding quantization. + quantized_latent = one_hot_encoding @ self.embedding.weight # [BHW, D] + quantized_latent = rearrange( + quantized_latent, "(b h w) d -> b h w d", b=b, h=h, w=w + ) + + return quantized_latent + + def vq_loss(self, latent: Tensor, quantized_latent: Tensor) -> Tensor: + """Vector Quantization loss. + + The vector quantization algorithm allows us to create a codebook. The VQ + algorithm works by moving the embedding vectors towards the encoder outputs. + + The embedding loss moves the embedding vector towards the encoder outputs. The + .detach() works as the stop gradient (sg) described in the paper. + + Because the volume of the embedding space is dimensionless, it can arbitarily + grow if the embeddings are not trained as fast as the encoder parameters. To + mitigate this, a commitment loss is added in the second term which makes sure + that the encoder commits to an embedding and that its output does not grow. + + Args: + latent (Tensor): The encoder output. + quantized_latent (Tensor): The quantized latent. + + Returns: + Tensor: The combinded VQ loss. + + """ + embedding_loss = F.mse_loss(quantized_latent, latent.detach()) + commitment_loss = F.mse_loss(quantized_latent.detach(), latent) + return embedding_loss + self.beta * commitment_loss + + def forward(self, latent: Tensor) -> Tensor: + """Forward pass that returns the quantized vector and the vq loss.""" + # Rearrange latent representation s.t. the hidden dim is at the end. + latent = rearrange(latent, "b d h w -> b h w d") + + # Maps latent to the nearest code in the codebook. + quantized_latent = self.discretization_bottleneck(latent) + + loss = self.vq_loss(latent, quantized_latent) + + # Add residue to the quantized latent. + quantized_latent = latent + (quantized_latent - latent).detach() + + # Rearrange the quantized shape back to the original shape. + quantized_latent = rearrange(quantized_latent, "b h w d -> b d h w") + + return quantized_latent, loss -- cgit v1.2.3-70-g09d2