From 3ab82ad36bce6fa698a13a029a0694b75a5947b7 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Fri, 6 Aug 2021 02:42:45 +0200
Subject: Fix VQVAE into en/decoder, bug in wandb artifact code uploading

---
 text_recognizer/networks/vqvae/vqvae.py | 98 ++++++++------------------------
 1 file changed, 23 insertions(+), 75 deletions(-)

(limited to 'text_recognizer/networks/vqvae/vqvae.py')

diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py
index 1585d40..0646119 100644
--- a/text_recognizer/networks/vqvae/vqvae.py
+++ b/text_recognizer/networks/vqvae/vqvae.py
@@ -1,13 +1,9 @@
 """The VQ-VAE."""
 from typing import Tuple
 
-import torch
 from torch import nn
 from torch import Tensor
-import torch.nn.functional as F
 
-from text_recognizer.networks.vqvae.decoder import Decoder
-from text_recognizer.networks.vqvae.encoder import Encoder
 from text_recognizer.networks.vqvae.quantizer import VectorQuantizer
 
 
@@ -16,93 +12,45 @@ class VQVAE(nn.Module):
 
     def __init__(
         self,
-        in_channels: int,
-        res_channels: int,
-        num_residual_layers: int,
+        encoder: nn.Module,
+        decoder: nn.Module,
+        hidden_dim: int,
         embedding_dim: int,
         num_embeddings: int,
         decay: float = 0.99,
-        activation: str = "mish",
     ) -> None:
         super().__init__()
-        # Encoders
-        self.btm_encoder = Encoder(
-            in_channels=1,
-            out_channels=embedding_dim,
-            res_channels=res_channels,
-            num_residual_layers=num_residual_layers,
-            embedding_dim=embedding_dim,
-            activation=activation,
+        self.encoder = encoder
+        self.decoder = decoder
+        self.pre_codebook_conv = nn.Conv2d(
+            in_channels=hidden_dim, out_channels=embedding_dim, kernel_size=1
         )
-
-        self.top_encoder = Encoder(
-            in_channels=embedding_dim,
-            out_channels=embedding_dim,
-            res_channels=res_channels,
-            num_residual_layers=num_residual_layers,
-            embedding_dim=embedding_dim,
-            activation=activation,
+        self.post_codebook_conv = nn.Conv2d(
+            in_channels=embedding_dim, out_channels=hidden_dim, kernel_size=1
         )
-
-        # Quantizers
-        self.btm_quantizer = VectorQuantizer(
-            num_embeddings=num_embeddings, embedding_dim=embedding_dim, decay=decay,
-        )
 
-        self.top_quantizer = VectorQuantizer(
+        self.quantizer = VectorQuantizer(
             num_embeddings=num_embeddings, embedding_dim=embedding_dim, decay=decay,
         )
 
-        # Decoders
-        self.top_decoder = Decoder(
-            in_channels=embedding_dim,
-            out_channels=embedding_dim,
-            embedding_dim=embedding_dim,
-            res_channels=res_channels,
-            num_residual_layers=num_residual_layers,
-            activation=activation,
-        )
-
-        self.btm_decoder = Decoder(
-            in_channels=2 * embedding_dim,
-            out_channels=in_channels,
-            embedding_dim=embedding_dim,
-            res_channels=res_channels,
-            num_residual_layers=num_residual_layers,
-            activation=activation,
-        )
-
-    def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+    def encode(self, x: Tensor) -> Tensor:
         """Encodes input to a latent code."""
-        z_btm = self.btm_encoder(x)
-        z_top = self.top_encoder(z_btm)
-        return z_btm, z_top
+        z_e = self.encoder(x)
+        return self.pre_codebook_conv(z_e)
 
-    def quantize(
-        self, z_btm: Tensor, z_top: Tensor
-    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
-        q_btm, vq_btm_loss = self.top_quantizer(z_btm)
-        q_top, vq_top_loss = self.top_quantizer(z_top)
-        return q_btm, vq_btm_loss, q_top, vq_top_loss
+    def quantize(self, z_e: Tensor) -> Tuple[Tensor, Tensor]:
+        z_q, vq_loss = self.quantizer(z_e)
+        return z_q, vq_loss
 
-    def decode(self, q_btm: Tensor, q_top: Tensor) -> Tuple[Tensor, Tensor]:
+    def decode(self, z_q: Tensor) -> Tensor:
         """Reconstructs input from latent codes."""
-        d_top = self.top_decoder(q_top)
-        x_hat = self.btm_decoder(torch.cat((d_top, q_btm), dim=1))
-        return d_top, x_hat
-
-    def loss_fn(
-        self, vq_btm_loss: Tensor, vq_top_loss: Tensor, d_top: Tensor, z_btm: Tensor
-    ) -> Tensor:
-        """Calculates the latent loss."""
-        return 0.5 * (vq_top_loss + vq_btm_loss) + F.mse_loss(d_top, z_btm)
+        z = self.post_codebook_conv(z_q)
+        x_hat = self.decoder(z)
+        return x_hat
 
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
         """Compresses and decompresses input."""
-        z_btm, z_top = self.encode(x)
-        q_btm, vq_btm_loss, q_top, vq_top_loss = self.quantize(z_btm=z_btm, z_top=z_top)
-        d_top, x_hat = self.decode(q_btm=q_btm, q_top=q_top)
-        vq_loss = self.loss_fn(
-            vq_btm_loss=vq_btm_loss, vq_top_loss=vq_top_loss, d_top=d_top, z_btm=z_btm
-        )
+        z_e = self.encode(x)
+        z_q, vq_loss = self.quantize(z_e)
+        x_hat = self.decode(z_q)
         return x_hat, vq_loss
--
cgit v1.2.3-70-g09d2
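A minimal usage sketch of the refactored wiring (encode -> pre-codebook 1x1 conv -> quantize -> post-codebook 1x1 conv -> decode). The stand-in encoder/decoder modules, channel sizes, and input shape below are assumptions for illustration only; they are not part of the patch, which only requires that the injected encoder emits `hidden_dim` channels.

import torch
from torch import nn

from text_recognizer.networks.vqvae.vqvae import VQVAE

# Hypothetical stand-ins for the project's encoder/decoder modules; any pair of
# nn.Modules with matching channel counts can be injected the same way.
hidden_dim = 64
encoder = nn.Sequential(
    nn.Conv2d(1, hidden_dim, kernel_size=4, stride=2, padding=1),  # 1 -> hidden_dim channels, downsample x2
    nn.Mish(),
)
decoder = nn.ConvTranspose2d(hidden_dim, 1, kernel_size=4, stride=2, padding=1)  # upsample back to input size

vqvae = VQVAE(
    encoder=encoder,
    decoder=decoder,
    hidden_dim=hidden_dim,   # channels produced by the encoder
    embedding_dim=32,        # codebook vector size
    num_embeddings=512,      # number of codebook entries
    decay=0.99,              # EMA decay passed to VectorQuantizer
)

x = torch.randn(2, 1, 32, 128)   # assumed grayscale line-image batch
x_hat, vq_loss = vqvae(x)        # forward: encode -> quantize -> decode
print(x_hat.shape, vq_loss)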