From 7e8e54e84c63171e748bbf09516fd517e6821ace Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sat, 20 Mar 2021 18:09:06 +0100 Subject: Inital commit for refactoring to lightning --- text_recognizer/networks/vqvae/__init__.py | 5 + text_recognizer/networks/vqvae/decoder.py | 133 +++++++++++++++++++ text_recognizer/networks/vqvae/encoder.py | 147 +++++++++++++++++++++ text_recognizer/networks/vqvae/vector_quantizer.py | 119 +++++++++++++++++ text_recognizer/networks/vqvae/vqvae.py | 74 +++++++++++ 5 files changed, 478 insertions(+) create mode 100644 text_recognizer/networks/vqvae/__init__.py create mode 100644 text_recognizer/networks/vqvae/decoder.py create mode 100644 text_recognizer/networks/vqvae/encoder.py create mode 100644 text_recognizer/networks/vqvae/vector_quantizer.py create mode 100644 text_recognizer/networks/vqvae/vqvae.py (limited to 'text_recognizer/networks/vqvae') diff --git a/text_recognizer/networks/vqvae/__init__.py b/text_recognizer/networks/vqvae/__init__.py new file mode 100644 index 0000000..763953c --- /dev/null +++ b/text_recognizer/networks/vqvae/__init__.py @@ -0,0 +1,5 @@ +"""VQ-VAE module.""" +from .decoder import Decoder +from .encoder import Encoder +from .vector_quantizer import VectorQuantizer +from .vqvae import VQVAE diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py new file mode 100644 index 0000000..8847aba --- /dev/null +++ b/text_recognizer/networks/vqvae/decoder.py @@ -0,0 +1,133 @@ +"""CNN decoder for the VQ-VAE.""" + +from typing import List, Optional, Tuple, Type + +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.util import activation_function +from text_recognizer.networks.vqvae.encoder import _ResidualBlock + + +class Decoder(nn.Module): + """A CNN encoder network.""" + + def __init__( + self, + channels: List[int], + kernel_sizes: List[int], + strides: List[int], + num_residual_layers: int, + embedding_dim: int, + upsampling: Optional[List[List[int]]] = None, + activation: str = "leaky_relu", + dropout_rate: float = 0.0, + ) -> None: + super().__init__() + + if dropout_rate: + if activation == "selu": + dropout = nn.AlphaDropout(p=dropout_rate) + else: + dropout = nn.Dropout(p=dropout_rate) + else: + dropout = None + + self.upsampling = upsampling + + self.res_block = nn.ModuleList([]) + self.upsampling_block = nn.ModuleList([]) + + self.embedding_dim = embedding_dim + activation = activation_function(activation) + + # Configure encoder. + self.decoder = self._build_decoder( + channels, kernel_sizes, strides, num_residual_layers, activation, dropout, + ) + + def _build_decompression_block( + self, + in_channels: int, + channels: int, + kernel_sizes: List[int], + strides: List[int], + activation: Type[nn.Module], + dropout: Optional[Type[nn.Module]], + ) -> nn.ModuleList: + modules = nn.ModuleList([]) + configuration = zip(channels, kernel_sizes, strides) + for i, (out_channels, kernel_size, stride) in enumerate(configuration): + modules.append( + nn.Sequential( + nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=1, + ), + activation, + ) + ) + + if i < len(self.upsampling): + modules.append(nn.Upsample(size=self.upsampling[i]),) + + if dropout is not None: + modules.append(dropout) + + in_channels = out_channels + + modules.extend( + nn.Sequential( + nn.ConvTranspose2d( + in_channels, 1, kernel_size=kernel_size, stride=stride, padding=1 + ), + nn.Tanh(), + ) + ) + + return modules + + def _build_decoder( + self, + channels: int, + kernel_sizes: List[int], + strides: List[int], + num_residual_layers: int, + activation: Type[nn.Module], + dropout: Optional[Type[nn.Module]], + ) -> nn.Sequential: + + self.res_block.append( + nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,) + ) + + # Bottleneck module. + self.res_block.extend( + nn.ModuleList( + [ + _ResidualBlock(channels[0], channels[0], dropout) + for i in range(num_residual_layers) + ] + ) + ) + + # Decompression module + self.upsampling_block.extend( + self._build_decompression_block( + channels[0], channels[1:], kernel_sizes, strides, activation, dropout + ) + ) + + self.res_block = nn.Sequential(*self.res_block) + self.upsampling_block = nn.Sequential(*self.upsampling_block) + + return nn.Sequential(self.res_block, self.upsampling_block) + + def forward(self, z_q: Tensor) -> Tensor: + """Reconstruct input from given codes.""" + x_reconstruction = self.decoder(z_q) + return x_reconstruction diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py new file mode 100644 index 0000000..d3adac5 --- /dev/null +++ b/text_recognizer/networks/vqvae/encoder.py @@ -0,0 +1,147 @@ +"""CNN encoder for the VQ-VAE.""" +from typing import List, Optional, Tuple, Type + +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.util import activation_function +from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer + + +class _ResidualBlock(nn.Module): + def __init__( + self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]], + ) -> None: + super().__init__() + self.block = [ + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False), + ] + + if dropout is not None: + self.block.append(dropout) + + self.block = nn.Sequential(*self.block) + + def forward(self, x: Tensor) -> Tensor: + """Apply the residual forward pass.""" + return x + self.block(x) + + +class Encoder(nn.Module): + """A CNN encoder network.""" + + def __init__( + self, + in_channels: int, + channels: List[int], + kernel_sizes: List[int], + strides: List[int], + num_residual_layers: int, + embedding_dim: int, + num_embeddings: int, + beta: float = 0.25, + activation: str = "leaky_relu", + dropout_rate: float = 0.0, + ) -> None: + super().__init__() + + if dropout_rate: + if activation == "selu": + dropout = nn.AlphaDropout(p=dropout_rate) + else: + dropout = nn.Dropout(p=dropout_rate) + else: + dropout = None + + self.embedding_dim = embedding_dim + self.num_embeddings = num_embeddings + self.beta = beta + activation = activation_function(activation) + + # Configure encoder. + self.encoder = self._build_encoder( + in_channels, + channels, + kernel_sizes, + strides, + num_residual_layers, + activation, + dropout, + ) + + # Configure Vector Quantizer. + self.vector_quantizer = VectorQuantizer( + self.num_embeddings, self.embedding_dim, self.beta + ) + + def _build_compression_block( + self, + in_channels: int, + channels: int, + kernel_sizes: List[int], + strides: List[int], + activation: Type[nn.Module], + dropout: Optional[Type[nn.Module]], + ) -> nn.ModuleList: + modules = nn.ModuleList([]) + configuration = zip(channels, kernel_sizes, strides) + for out_channels, kernel_size, stride in configuration: + modules.append( + nn.Sequential( + nn.Conv2d( + in_channels, out_channels, kernel_size, stride=stride, padding=1 + ), + activation, + ) + ) + + if dropout is not None: + modules.append(dropout) + + in_channels = out_channels + + return modules + + def _build_encoder( + self, + in_channels: int, + channels: int, + kernel_sizes: List[int], + strides: List[int], + num_residual_layers: int, + activation: Type[nn.Module], + dropout: Optional[Type[nn.Module]], + ) -> nn.Sequential: + encoder = nn.ModuleList([]) + + # compression module + encoder.extend( + self._build_compression_block( + in_channels, channels, kernel_sizes, strides, activation, dropout + ) + ) + + # Bottleneck module. + encoder.extend( + nn.ModuleList( + [ + _ResidualBlock(channels[-1], channels[-1], dropout) + for i in range(num_residual_layers) + ] + ) + ) + + encoder.append( + nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,) + ) + + return nn.Sequential(*encoder) + + def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: + """Encodes input into a discrete representation.""" + z_e = self.encoder(x) + z_q, vq_loss = self.vector_quantizer(z_e) + return z_q, vq_loss diff --git a/text_recognizer/networks/vqvae/vector_quantizer.py b/text_recognizer/networks/vqvae/vector_quantizer.py new file mode 100644 index 0000000..f92c7ee --- /dev/null +++ b/text_recognizer/networks/vqvae/vector_quantizer.py @@ -0,0 +1,119 @@ +"""Implementation of a Vector Quantized Variational AutoEncoder. + +Reference: +https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py + +""" + +from einops import rearrange +import torch +from torch import nn +from torch import Tensor +from torch.nn import functional as F + + +class VectorQuantizer(nn.Module): + """The codebook that contains quantized vectors.""" + + def __init__( + self, num_embeddings: int, embedding_dim: int, beta: float = 0.25 + ) -> None: + super().__init__() + self.K = num_embeddings + self.D = embedding_dim + self.beta = beta + + self.embedding = nn.Embedding(self.K, self.D) + + # Initialize the codebook. + nn.init.uniform_(self.embedding.weight, -1 / self.K, 1 / self.K) + + def discretization_bottleneck(self, latent: Tensor) -> Tensor: + """Computes the code nearest to the latent representation. + + First we compute the posterior categorical distribution, and then map + the latent representation to the nearest element of the embedding. + + Args: + latent (Tensor): The latent representation. + + Shape: + - latent :math:`(B x H x W, D)` + + Returns: + Tensor: The quantized embedding vector. + + """ + # Store latent shape. + b, h, w, d = latent.shape + + # Flatten the latent representation to 2D. + latent = rearrange(latent, "b h w d -> (b h w) d") + + # Compute the L2 distance between the latents and the embeddings. + l2_distance = ( + torch.sum(latent ** 2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight ** 2, dim=1) + - 2 * latent @ self.embedding.weight.t() + ) # [BHW x K] + + # Find the embedding k nearest to each latent. + encoding_indices = torch.argmin(l2_distance, dim=1).unsqueeze(1) # [BHW, 1] + + # Convert to one-hot encodings, aka discrete bottleneck. + one_hot_encoding = torch.zeros( + encoding_indices.shape[0], self.K, device=latent.device + ) + one_hot_encoding.scatter_(1, encoding_indices, 1) # [BHW x K] + + # Embedding quantization. + quantized_latent = one_hot_encoding @ self.embedding.weight # [BHW, D] + quantized_latent = rearrange( + quantized_latent, "(b h w) d -> b h w d", b=b, h=h, w=w + ) + + return quantized_latent + + def vq_loss(self, latent: Tensor, quantized_latent: Tensor) -> Tensor: + """Vector Quantization loss. + + The vector quantization algorithm allows us to create a codebook. The VQ + algorithm works by moving the embedding vectors towards the encoder outputs. + + The embedding loss moves the embedding vector towards the encoder outputs. The + .detach() works as the stop gradient (sg) described in the paper. + + Because the volume of the embedding space is dimensionless, it can arbitarily + grow if the embeddings are not trained as fast as the encoder parameters. To + mitigate this, a commitment loss is added in the second term which makes sure + that the encoder commits to an embedding and that its output does not grow. + + Args: + latent (Tensor): The encoder output. + quantized_latent (Tensor): The quantized latent. + + Returns: + Tensor: The combinded VQ loss. + + """ + embedding_loss = F.mse_loss(quantized_latent, latent.detach()) + commitment_loss = F.mse_loss(quantized_latent.detach(), latent) + return embedding_loss + self.beta * commitment_loss + + def forward(self, latent: Tensor) -> Tensor: + """Forward pass that returns the quantized vector and the vq loss.""" + # Rearrange latent representation s.t. the hidden dim is at the end. + latent = rearrange(latent, "b d h w -> b h w d") + + # Maps latent to the nearest code in the codebook. + quantized_latent = self.discretization_bottleneck(latent) + + loss = self.vq_loss(latent, quantized_latent) + + # Add residue to the quantized latent. + quantized_latent = latent + (quantized_latent - latent).detach() + + # Rearrange the quantized shape back to the original shape. + quantized_latent = rearrange(quantized_latent, "b h w d -> b d h w") + + return quantized_latent, loss diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py new file mode 100644 index 0000000..50448b4 --- /dev/null +++ b/text_recognizer/networks/vqvae/vqvae.py @@ -0,0 +1,74 @@ +"""The VQ-VAE.""" + +from typing import List, Optional, Tuple, Type + +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.vqvae import Decoder, Encoder + + +class VQVAE(nn.Module): + """Vector Quantized Variational AutoEncoder.""" + + def __init__( + self, + in_channels: int, + channels: List[int], + kernel_sizes: List[int], + strides: List[int], + num_residual_layers: int, + embedding_dim: int, + num_embeddings: int, + upsampling: Optional[List[List[int]]] = None, + beta: float = 0.25, + activation: str = "leaky_relu", + dropout_rate: float = 0.0, + ) -> None: + super().__init__() + + # configure encoder. + self.encoder = Encoder( + in_channels, + channels, + kernel_sizes, + strides, + num_residual_layers, + embedding_dim, + num_embeddings, + beta, + activation, + dropout_rate, + ) + + # Configure decoder. + channels.reverse() + kernel_sizes.reverse() + strides.reverse() + self.decoder = Decoder( + channels, + kernel_sizes, + strides, + num_residual_layers, + embedding_dim, + upsampling, + activation, + dropout_rate, + ) + + def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]: + """Encodes input to a latent code.""" + return self.encoder(x) + + def decode(self, z_q: Tensor) -> Tensor: + """Reconstructs input from latent codes.""" + return self.decoder(z_q) + + def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: + """Compresses and decompresses input.""" + if len(x.shape) < 4: + x = x[(None,) * (4 - len(x.shape))] + z_q, vq_loss = self.encode(x) + x_reconstruction = self.decode(z_q) + return x_reconstruction, vq_loss -- cgit v1.2.3-70-g09d2