diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-21 21:34:53 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-21 21:34:53 +0100 |
commit | b44de0e11281c723ec426f8bec8ca0897ecfe3ff (patch) | |
tree | 998841a3a681d3dedfbe8470c1b8544b4dcbe7a2 /text_recognizer/networks/vqvae/vqvae.py | |
parent | 3b2fb0fd977a6aff4dcf88e1a0f99faac51e05b1 (diff) |
Remove VQVAE stuff, did not work...
Diffstat (limited to 'text_recognizer/networks/vqvae/vqvae.py')
-rw-r--r-- | text_recognizer/networks/vqvae/vqvae.py | 42 |
1 files changed, 0 insertions, 42 deletions
diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py deleted file mode 100644 index 5560e12..0000000 --- a/text_recognizer/networks/vqvae/vqvae.py +++ /dev/null @@ -1,42 +0,0 @@ -"""The VQ-VAE.""" -from typing import Tuple - -from torch import nn -from torch import Tensor - -from text_recognizer.networks.quantizer.quantizer import VectorQuantizer - - -class VQVAE(nn.Module): - """Vector Quantized Variational AutoEncoder.""" - - def __init__( - self, - encoder: nn.Module, - decoder: nn.Module, - quantizer: VectorQuantizer, - ) -> None: - super().__init__() - self.encoder = encoder - self.decoder = decoder - self.quantizer = quantizer - - def encode(self, x: Tensor) -> Tensor: - """Encodes input to a latent code.""" - return self.encoder(x) - - def quantize(self, z_e: Tensor) -> Tuple[Tensor, Tensor]: - """Quantizes the encoded latent vectors.""" - z_q, _, commitment_loss = self.quantizer(z_e) - return z_q, commitment_loss - - def decode(self, z_q: Tensor) -> Tensor: - """Reconstructs input from latent codes.""" - return self.decoder(z_q) - - def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: - """Compresses and decompresses input.""" - z_e = self.encode(x) - z_q, commitment_loss = self.quantize(z_e) - x_hat = self.decode(z_q) - return x_hat, commitment_loss |