Diffstat (limited to 'text_recognizer/networks/vqvae/vqvae.py')
-rw-r--r-- | text_recognizer/networks/vqvae/vqvae.py | 122
1 file changed, 78 insertions, 44 deletions
diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py
index 5aa929b..1585d40 100644
--- a/text_recognizer/networks/vqvae/vqvae.py
+++ b/text_recognizer/networks/vqvae/vqvae.py
@@ -1,10 +1,14 @@
 """The VQ-VAE."""
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Tuple
 
+import torch
 from torch import nn
 from torch import Tensor
+import torch.nn.functional as F
 
-from text_recognizer.networks.vqvae import Decoder, Encoder
+from text_recognizer.networks.vqvae.decoder import Decoder
+from text_recognizer.networks.vqvae.encoder import Encoder
+from text_recognizer.networks.vqvae.quantizer import VectorQuantizer
 
 
 class VQVAE(nn.Module):
@@ -13,62 +17,92 @@ class VQVAE(nn.Module):
     def __init__(
         self,
         in_channels: int,
-        channels: List[int],
-        kernel_sizes: List[int],
-        strides: List[int],
+        res_channels: int,
         num_residual_layers: int,
         embedding_dim: int,
         num_embeddings: int,
-        upsampling: Optional[List[List[int]]] = None,
-        beta: float = 0.25,
-        activation: str = "leaky_relu",
-        dropout_rate: float = 0.0,
-        *args: Any,
-        **kwargs: Dict,
+        decay: float = 0.99,
+        activation: str = "mish",
     ) -> None:
         super().__init__()
+        # Encoders
+        self.btm_encoder = Encoder(
+            in_channels=1,
+            out_channels=embedding_dim,
+            res_channels=res_channels,
+            num_residual_layers=num_residual_layers,
+            embedding_dim=embedding_dim,
+            activation=activation,
+        )
+
+        self.top_encoder = Encoder(
+            in_channels=embedding_dim,
+            out_channels=embedding_dim,
+            res_channels=res_channels,
+            num_residual_layers=num_residual_layers,
+            embedding_dim=embedding_dim,
+            activation=activation,
+        )
+
+        # Quantizers
+        self.btm_quantizer = VectorQuantizer(
+            num_embeddings=num_embeddings, embedding_dim=embedding_dim, decay=decay,
+        )
 
-        # configure encoder.
-        self.encoder = Encoder(
-            in_channels,
-            channels,
-            kernel_sizes,
-            strides,
-            num_residual_layers,
-            embedding_dim,
-            num_embeddings,
-            beta,
-            activation,
-            dropout_rate,
+        self.top_quantizer = VectorQuantizer(
+            num_embeddings=num_embeddings, embedding_dim=embedding_dim, decay=decay,
         )
-        # Configure decoder.
-        channels.reverse()
-        kernel_sizes.reverse()
-        strides.reverse()
-        self.decoder = Decoder(
-            channels,
-            kernel_sizes,
-            strides,
-            num_residual_layers,
-            embedding_dim,
-            upsampling,
-            activation,
-            dropout_rate,
+
+        # Decoders
+        self.top_decoder = Decoder(
+            in_channels=embedding_dim,
+            out_channels=embedding_dim,
+            embedding_dim=embedding_dim,
+            res_channels=res_channels,
+            num_residual_layers=num_residual_layers,
+            activation=activation,
+        )
+
+        self.btm_decoder = Decoder(
+            in_channels=2 * embedding_dim,
+            out_channels=in_channels,
+            embedding_dim=embedding_dim,
+            res_channels=res_channels,
+            num_residual_layers=num_residual_layers,
+            activation=activation,
         )
 
     def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
         """Encodes input to a latent code."""
-        return self.encoder(x)
+        z_btm = self.btm_encoder(x)
+        z_top = self.top_encoder(z_btm)
+        return z_btm, z_top
+
+    def quantize(
+        self, z_btm: Tensor, z_top: Tensor
+    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+        q_btm, vq_btm_loss = self.btm_quantizer(z_btm)
+        q_top, vq_top_loss = self.top_quantizer(z_top)
+        return q_btm, vq_btm_loss, q_top, vq_top_loss
 
-    def decode(self, z_q: Tensor) -> Tensor:
+    def decode(self, q_btm: Tensor, q_top: Tensor) -> Tuple[Tensor, Tensor]:
         """Reconstructs input from latent codes."""
-        return self.decoder(z_q)
+        d_top = self.top_decoder(q_top)
+        x_hat = self.btm_decoder(torch.cat((d_top, q_btm), dim=1))
+        return d_top, x_hat
+
+    def loss_fn(
+        self, vq_btm_loss: Tensor, vq_top_loss: Tensor, d_top: Tensor, z_btm: Tensor
+    ) -> Tensor:
+        """Calculates the latent loss."""
+        return 0.5 * (vq_top_loss + vq_btm_loss) + F.mse_loss(d_top, z_btm)
 
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
         """Compresses and decompresses input."""
-        if len(x.shape) < 4:
-            x = x[(None,) * (4 - len(x.shape))]
-        z_q, vq_loss = self.encode(x)
-        x_reconstruction = self.decode(z_q)
-        return x_reconstruction, vq_loss
+        z_btm, z_top = self.encode(x)
+        q_btm, vq_btm_loss, q_top, vq_top_loss = self.quantize(z_btm=z_btm, z_top=z_top)
+        d_top, x_hat = self.decode(q_btm=q_btm, q_top=q_top)
+        vq_loss = self.loss_fn(
+            vq_btm_loss=vq_btm_loss, vq_top_loss=vq_top_loss, d_top=d_top, z_btm=z_btm
+        )
+        return x_hat, vq_loss
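
For orientation, a minimal usage sketch of the refactored hierarchical VQ-VAE follows. It relies only on the constructor signature and the forward() return value visible in this diff; the concrete hyperparameter values, the input resolution, and the way the two loss terms are combined are illustrative assumptions, not part of the commit.

import torch
import torch.nn.functional as F

from text_recognizer.networks.vqvae.vqvae import VQVAE

# Hyperparameter values below are assumed for illustration only.
model = VQVAE(
    in_channels=1,           # grayscale input images
    res_channels=32,
    num_residual_layers=2,
    embedding_dim=64,
    num_embeddings=512,
    decay=0.99,
    activation="mish",
)

# Dummy batch; the real spatial size depends on the dataset and on the
# downsampling factors of the encoders.
x = torch.randn(4, 1, 128, 256)

# forward() returns the reconstruction and the combined latent (VQ) loss.
x_hat, vq_loss = model(x)

# A typical training objective adds a reconstruction term to the returned
# latent loss; the equal weighting here is an assumption.
loss = F.mse_loss(x_hat, x) + vq_loss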