From 700ce6ed83867601de0ae55032afdd5e12438258 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 17 Nov 2021 22:43:12 +0100 Subject: Update vqvae with new quantizer --- text_recognizer/networks/vqvae/vqvae.py | 26 ++++++-------------------- 1 file changed, 6 insertions(+), 20 deletions(-) (limited to 'text_recognizer/networks/vqvae') diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py index d876ca1..5560e12 100644 --- a/text_recognizer/networks/vqvae/vqvae.py +++ b/text_recognizer/networks/vqvae/vqvae.py @@ -4,7 +4,7 @@ from typing import Tuple from torch import nn from torch import Tensor -from text_recognizer.networks.vqvae.quantizer import VectorQuantizer +from text_recognizer.networks.quantizer.quantizer import VectorQuantizer class VQVAE(nn.Module): @@ -14,39 +14,25 @@ class VQVAE(nn.Module): self, encoder: nn.Module, decoder: nn.Module, - hidden_dim: int, - embedding_dim: int, - num_embeddings: int, - decay: float = 0.99, + quantizer: VectorQuantizer, ) -> None: super().__init__() self.encoder = encoder self.decoder = decoder - self.pre_codebook_conv = nn.Conv2d( - in_channels=hidden_dim, out_channels=embedding_dim, kernel_size=1 - ) - self.post_codebook_conv = nn.Conv2d( - in_channels=embedding_dim, out_channels=hidden_dim, kernel_size=1 - ) - self.quantizer = VectorQuantizer( - num_embeddings=num_embeddings, embedding_dim=embedding_dim, decay=decay, - ) + self.quantizer = quantizer def encode(self, x: Tensor) -> Tensor: """Encodes input to a latent code.""" - z_e = self.encoder(x) - return self.pre_codebook_conv(z_e) + return self.encoder(x) def quantize(self, z_e: Tensor) -> Tuple[Tensor, Tensor]: """Quantizes the encoded latent vectors.""" - z_q, commitment_loss = self.quantizer(z_e) + z_q, _, commitment_loss = self.quantizer(z_e) return z_q, commitment_loss def decode(self, z_q: Tensor) -> Tensor: """Reconstructs input from latent codes.""" - z = self.post_codebook_conv(z_q) - x_hat = self.decoder(z) - return x_hat + return self.decoder(z_q) def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: """Compresses and decompresses input.""" -- cgit v1.2.3-70-g09d2