diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-20 18:09:06 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-20 18:09:06 +0100 |
commit | 7e8e54e84c63171e748bbf09516fd517e6821ace (patch) | |
tree | 996093f75a5d488dddf7ea1f159ed343a561ef89 /text_recognizer/networks/vqvae/vqvae.py | |
parent | b0719d84138b6bbe5f04a4982dfca673aea1a368 (diff) |
Inital commit for refactoring to lightning
Diffstat (limited to 'text_recognizer/networks/vqvae/vqvae.py')
-rw-r--r-- | text_recognizer/networks/vqvae/vqvae.py | 74 |
1 files changed, 74 insertions, 0 deletions
diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py new file mode 100644 index 0000000..50448b4 --- /dev/null +++ b/text_recognizer/networks/vqvae/vqvae.py @@ -0,0 +1,74 @@ +"""The VQ-VAE.""" + +from typing import List, Optional, Tuple, Type + +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.vqvae import Decoder, Encoder + + +class VQVAE(nn.Module): + """Vector Quantized Variational AutoEncoder.""" + + def __init__( + self, + in_channels: int, + channels: List[int], + kernel_sizes: List[int], + strides: List[int], + num_residual_layers: int, + embedding_dim: int, + num_embeddings: int, + upsampling: Optional[List[List[int]]] = None, + beta: float = 0.25, + activation: str = "leaky_relu", + dropout_rate: float = 0.0, + ) -> None: + super().__init__() + + # configure encoder. + self.encoder = Encoder( + in_channels, + channels, + kernel_sizes, + strides, + num_residual_layers, + embedding_dim, + num_embeddings, + beta, + activation, + dropout_rate, + ) + + # Configure decoder. + channels.reverse() + kernel_sizes.reverse() + strides.reverse() + self.decoder = Decoder( + channels, + kernel_sizes, + strides, + num_residual_layers, + embedding_dim, + upsampling, + activation, + dropout_rate, + ) + + def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]: + """Encodes input to a latent code.""" + return self.encoder(x) + + def decode(self, z_q: Tensor) -> Tensor: + """Reconstructs input from latent codes.""" + return self.decoder(z_q) + + def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: + """Compresses and decompresses input.""" + if len(x.shape) < 4: + x = x[(None,) * (4 - len(x.shape))] + z_q, vq_loss = self.encode(x) + x_reconstruction = self.decode(z_q) + return x_reconstruction, vq_loss |