summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/vqvae
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/vqvae')
-rw-r--r--text_recognizer/networks/vqvae/__init__.py5
-rw-r--r--text_recognizer/networks/vqvae/decoder.py133
-rw-r--r--text_recognizer/networks/vqvae/encoder.py147
-rw-r--r--text_recognizer/networks/vqvae/vector_quantizer.py119
-rw-r--r--text_recognizer/networks/vqvae/vqvae.py74
5 files changed, 478 insertions, 0 deletions
diff --git a/text_recognizer/networks/vqvae/__init__.py b/text_recognizer/networks/vqvae/__init__.py
new file mode 100644
index 0000000..763953c
--- /dev/null
+++ b/text_recognizer/networks/vqvae/__init__.py
@@ -0,0 +1,5 @@
+"""VQ-VAE module."""
+from .decoder import Decoder
+from .encoder import Encoder
+from .vector_quantizer import VectorQuantizer
+from .vqvae import VQVAE
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
new file mode 100644
index 0000000..8847aba
--- /dev/null
+++ b/text_recognizer/networks/vqvae/decoder.py
@@ -0,0 +1,133 @@
+"""CNN decoder for the VQ-VAE."""
+
+from typing import List, Optional, Tuple, Type
+
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import activation_function
+from text_recognizer.networks.vqvae.encoder import _ResidualBlock
+
+
+class Decoder(nn.Module):
+ """A CNN encoder network."""
+
+ def __init__(
+ self,
+ channels: List[int],
+ kernel_sizes: List[int],
+ strides: List[int],
+ num_residual_layers: int,
+ embedding_dim: int,
+ upsampling: Optional[List[List[int]]] = None,
+ activation: str = "leaky_relu",
+ dropout_rate: float = 0.0,
+ ) -> None:
+ super().__init__()
+
+ if dropout_rate:
+ if activation == "selu":
+ dropout = nn.AlphaDropout(p=dropout_rate)
+ else:
+ dropout = nn.Dropout(p=dropout_rate)
+ else:
+ dropout = None
+
+ self.upsampling = upsampling
+
+ self.res_block = nn.ModuleList([])
+ self.upsampling_block = nn.ModuleList([])
+
+ self.embedding_dim = embedding_dim
+ activation = activation_function(activation)
+
+ # Configure encoder.
+ self.decoder = self._build_decoder(
+ channels, kernel_sizes, strides, num_residual_layers, activation, dropout,
+ )
+
+ def _build_decompression_block(
+ self,
+ in_channels: int,
+ channels: int,
+ kernel_sizes: List[int],
+ strides: List[int],
+ activation: Type[nn.Module],
+ dropout: Optional[Type[nn.Module]],
+ ) -> nn.ModuleList:
+ modules = nn.ModuleList([])
+ configuration = zip(channels, kernel_sizes, strides)
+ for i, (out_channels, kernel_size, stride) in enumerate(configuration):
+ modules.append(
+ nn.Sequential(
+ nn.ConvTranspose2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=1,
+ ),
+ activation,
+ )
+ )
+
+ if i < len(self.upsampling):
+ modules.append(nn.Upsample(size=self.upsampling[i]),)
+
+ if dropout is not None:
+ modules.append(dropout)
+
+ in_channels = out_channels
+
+ modules.extend(
+ nn.Sequential(
+ nn.ConvTranspose2d(
+ in_channels, 1, kernel_size=kernel_size, stride=stride, padding=1
+ ),
+ nn.Tanh(),
+ )
+ )
+
+ return modules
+
+ def _build_decoder(
+ self,
+ channels: int,
+ kernel_sizes: List[int],
+ strides: List[int],
+ num_residual_layers: int,
+ activation: Type[nn.Module],
+ dropout: Optional[Type[nn.Module]],
+ ) -> nn.Sequential:
+
+ self.res_block.append(
+ nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,)
+ )
+
+ # Bottleneck module.
+ self.res_block.extend(
+ nn.ModuleList(
+ [
+ _ResidualBlock(channels[0], channels[0], dropout)
+ for i in range(num_residual_layers)
+ ]
+ )
+ )
+
+ # Decompression module
+ self.upsampling_block.extend(
+ self._build_decompression_block(
+ channels[0], channels[1:], kernel_sizes, strides, activation, dropout
+ )
+ )
+
+ self.res_block = nn.Sequential(*self.res_block)
+ self.upsampling_block = nn.Sequential(*self.upsampling_block)
+
+ return nn.Sequential(self.res_block, self.upsampling_block)
+
+ def forward(self, z_q: Tensor) -> Tensor:
+ """Reconstruct input from given codes."""
+ x_reconstruction = self.decoder(z_q)
+ return x_reconstruction
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
new file mode 100644
index 0000000..d3adac5
--- /dev/null
+++ b/text_recognizer/networks/vqvae/encoder.py
@@ -0,0 +1,147 @@
+"""CNN encoder for the VQ-VAE."""
+from typing import List, Optional, Tuple, Type
+
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import activation_function
+from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer
+
+
+class _ResidualBlock(nn.Module):
+ def __init__(
+ self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]],
+ ) -> None:
+ super().__init__()
+ self.block = [
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False),
+ ]
+
+ if dropout is not None:
+ self.block.append(dropout)
+
+ self.block = nn.Sequential(*self.block)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Apply the residual forward pass."""
+ return x + self.block(x)
+
+
+class Encoder(nn.Module):
+ """A CNN encoder network."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ channels: List[int],
+ kernel_sizes: List[int],
+ strides: List[int],
+ num_residual_layers: int,
+ embedding_dim: int,
+ num_embeddings: int,
+ beta: float = 0.25,
+ activation: str = "leaky_relu",
+ dropout_rate: float = 0.0,
+ ) -> None:
+ super().__init__()
+
+ if dropout_rate:
+ if activation == "selu":
+ dropout = nn.AlphaDropout(p=dropout_rate)
+ else:
+ dropout = nn.Dropout(p=dropout_rate)
+ else:
+ dropout = None
+
+ self.embedding_dim = embedding_dim
+ self.num_embeddings = num_embeddings
+ self.beta = beta
+ activation = activation_function(activation)
+
+ # Configure encoder.
+ self.encoder = self._build_encoder(
+ in_channels,
+ channels,
+ kernel_sizes,
+ strides,
+ num_residual_layers,
+ activation,
+ dropout,
+ )
+
+ # Configure Vector Quantizer.
+ self.vector_quantizer = VectorQuantizer(
+ self.num_embeddings, self.embedding_dim, self.beta
+ )
+
+ def _build_compression_block(
+ self,
+ in_channels: int,
+ channels: int,
+ kernel_sizes: List[int],
+ strides: List[int],
+ activation: Type[nn.Module],
+ dropout: Optional[Type[nn.Module]],
+ ) -> nn.ModuleList:
+ modules = nn.ModuleList([])
+ configuration = zip(channels, kernel_sizes, strides)
+ for out_channels, kernel_size, stride in configuration:
+ modules.append(
+ nn.Sequential(
+ nn.Conv2d(
+ in_channels, out_channels, kernel_size, stride=stride, padding=1
+ ),
+ activation,
+ )
+ )
+
+ if dropout is not None:
+ modules.append(dropout)
+
+ in_channels = out_channels
+
+ return modules
+
+ def _build_encoder(
+ self,
+ in_channels: int,
+ channels: int,
+ kernel_sizes: List[int],
+ strides: List[int],
+ num_residual_layers: int,
+ activation: Type[nn.Module],
+ dropout: Optional[Type[nn.Module]],
+ ) -> nn.Sequential:
+ encoder = nn.ModuleList([])
+
+ # compression module
+ encoder.extend(
+ self._build_compression_block(
+ in_channels, channels, kernel_sizes, strides, activation, dropout
+ )
+ )
+
+ # Bottleneck module.
+ encoder.extend(
+ nn.ModuleList(
+ [
+ _ResidualBlock(channels[-1], channels[-1], dropout)
+ for i in range(num_residual_layers)
+ ]
+ )
+ )
+
+ encoder.append(
+ nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,)
+ )
+
+ return nn.Sequential(*encoder)
+
+ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ """Encodes input into a discrete representation."""
+ z_e = self.encoder(x)
+ z_q, vq_loss = self.vector_quantizer(z_e)
+ return z_q, vq_loss
diff --git a/text_recognizer/networks/vqvae/vector_quantizer.py b/text_recognizer/networks/vqvae/vector_quantizer.py
new file mode 100644
index 0000000..f92c7ee
--- /dev/null
+++ b/text_recognizer/networks/vqvae/vector_quantizer.py
@@ -0,0 +1,119 @@
+"""Implementation of a Vector Quantized Variational AutoEncoder.
+
+Reference:
+https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py
+
+"""
+
+from einops import rearrange
+import torch
+from torch import nn
+from torch import Tensor
+from torch.nn import functional as F
+
+
+class VectorQuantizer(nn.Module):
+ """The codebook that contains quantized vectors."""
+
+ def __init__(
+ self, num_embeddings: int, embedding_dim: int, beta: float = 0.25
+ ) -> None:
+ super().__init__()
+ self.K = num_embeddings
+ self.D = embedding_dim
+ self.beta = beta
+
+ self.embedding = nn.Embedding(self.K, self.D)
+
+ # Initialize the codebook.
+ nn.init.uniform_(self.embedding.weight, -1 / self.K, 1 / self.K)
+
+ def discretization_bottleneck(self, latent: Tensor) -> Tensor:
+ """Computes the code nearest to the latent representation.
+
+ First we compute the posterior categorical distribution, and then map
+ the latent representation to the nearest element of the embedding.
+
+ Args:
+ latent (Tensor): The latent representation.
+
+ Shape:
+ - latent :math:`(B x H x W, D)`
+
+ Returns:
+ Tensor: The quantized embedding vector.
+
+ """
+ # Store latent shape.
+ b, h, w, d = latent.shape
+
+ # Flatten the latent representation to 2D.
+ latent = rearrange(latent, "b h w d -> (b h w) d")
+
+ # Compute the L2 distance between the latents and the embeddings.
+ l2_distance = (
+ torch.sum(latent ** 2, dim=1, keepdim=True)
+ + torch.sum(self.embedding.weight ** 2, dim=1)
+ - 2 * latent @ self.embedding.weight.t()
+ ) # [BHW x K]
+
+ # Find the embedding k nearest to each latent.
+ encoding_indices = torch.argmin(l2_distance, dim=1).unsqueeze(1) # [BHW, 1]
+
+ # Convert to one-hot encodings, aka discrete bottleneck.
+ one_hot_encoding = torch.zeros(
+ encoding_indices.shape[0], self.K, device=latent.device
+ )
+ one_hot_encoding.scatter_(1, encoding_indices, 1) # [BHW x K]
+
+ # Embedding quantization.
+ quantized_latent = one_hot_encoding @ self.embedding.weight # [BHW, D]
+ quantized_latent = rearrange(
+ quantized_latent, "(b h w) d -> b h w d", b=b, h=h, w=w
+ )
+
+ return quantized_latent
+
+ def vq_loss(self, latent: Tensor, quantized_latent: Tensor) -> Tensor:
+ """Vector Quantization loss.
+
+ The vector quantization algorithm allows us to create a codebook. The VQ
+ algorithm works by moving the embedding vectors towards the encoder outputs.
+
+ The embedding loss moves the embedding vector towards the encoder outputs. The
+ .detach() works as the stop gradient (sg) described in the paper.
+
+ Because the volume of the embedding space is dimensionless, it can arbitarily
+ grow if the embeddings are not trained as fast as the encoder parameters. To
+ mitigate this, a commitment loss is added in the second term which makes sure
+ that the encoder commits to an embedding and that its output does not grow.
+
+ Args:
+ latent (Tensor): The encoder output.
+ quantized_latent (Tensor): The quantized latent.
+
+ Returns:
+ Tensor: The combinded VQ loss.
+
+ """
+ embedding_loss = F.mse_loss(quantized_latent, latent.detach())
+ commitment_loss = F.mse_loss(quantized_latent.detach(), latent)
+ return embedding_loss + self.beta * commitment_loss
+
+ def forward(self, latent: Tensor) -> Tensor:
+ """Forward pass that returns the quantized vector and the vq loss."""
+ # Rearrange latent representation s.t. the hidden dim is at the end.
+ latent = rearrange(latent, "b d h w -> b h w d")
+
+ # Maps latent to the nearest code in the codebook.
+ quantized_latent = self.discretization_bottleneck(latent)
+
+ loss = self.vq_loss(latent, quantized_latent)
+
+ # Add residue to the quantized latent.
+ quantized_latent = latent + (quantized_latent - latent).detach()
+
+ # Rearrange the quantized shape back to the original shape.
+ quantized_latent = rearrange(quantized_latent, "b h w d -> b d h w")
+
+ return quantized_latent, loss
diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py
new file mode 100644
index 0000000..50448b4
--- /dev/null
+++ b/text_recognizer/networks/vqvae/vqvae.py
@@ -0,0 +1,74 @@
+"""The VQ-VAE."""
+
+from typing import List, Optional, Tuple, Type
+
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.vqvae import Decoder, Encoder
+
+
+class VQVAE(nn.Module):
+ """Vector Quantized Variational AutoEncoder."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ channels: List[int],
+ kernel_sizes: List[int],
+ strides: List[int],
+ num_residual_layers: int,
+ embedding_dim: int,
+ num_embeddings: int,
+ upsampling: Optional[List[List[int]]] = None,
+ beta: float = 0.25,
+ activation: str = "leaky_relu",
+ dropout_rate: float = 0.0,
+ ) -> None:
+ super().__init__()
+
+ # configure encoder.
+ self.encoder = Encoder(
+ in_channels,
+ channels,
+ kernel_sizes,
+ strides,
+ num_residual_layers,
+ embedding_dim,
+ num_embeddings,
+ beta,
+ activation,
+ dropout_rate,
+ )
+
+ # Configure decoder.
+ channels.reverse()
+ kernel_sizes.reverse()
+ strides.reverse()
+ self.decoder = Decoder(
+ channels,
+ kernel_sizes,
+ strides,
+ num_residual_layers,
+ embedding_dim,
+ upsampling,
+ activation,
+ dropout_rate,
+ )
+
+ def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ """Encodes input to a latent code."""
+ return self.encoder(x)
+
+ def decode(self, z_q: Tensor) -> Tensor:
+ """Reconstructs input from latent codes."""
+ return self.decoder(z_q)
+
+ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ """Compresses and decompresses input."""
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
+ z_q, vq_loss = self.encode(x)
+ x_reconstruction = self.decode(z_q)
+ return x_reconstruction, vq_loss