summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/vqvae/vqvae.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/networks/vqvae/vqvae.py')
-rw-r--r--src/text_recognizer/networks/vqvae/vqvae.py74
1 files changed, 0 insertions, 74 deletions
diff --git a/src/text_recognizer/networks/vqvae/vqvae.py b/src/text_recognizer/networks/vqvae/vqvae.py
deleted file mode 100644
index 50448b4..0000000
--- a/src/text_recognizer/networks/vqvae/vqvae.py
+++ /dev/null
@@ -1,74 +0,0 @@
-"""The VQ-VAE."""
-
-from typing import List, Optional, Tuple, Type
-
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.vqvae import Decoder, Encoder
-
-
-class VQVAE(nn.Module):
- """Vector Quantized Variational AutoEncoder."""
-
- def __init__(
- self,
- in_channels: int,
- channels: List[int],
- kernel_sizes: List[int],
- strides: List[int],
- num_residual_layers: int,
- embedding_dim: int,
- num_embeddings: int,
- upsampling: Optional[List[List[int]]] = None,
- beta: float = 0.25,
- activation: str = "leaky_relu",
- dropout_rate: float = 0.0,
- ) -> None:
- super().__init__()
-
- # configure encoder.
- self.encoder = Encoder(
- in_channels,
- channels,
- kernel_sizes,
- strides,
- num_residual_layers,
- embedding_dim,
- num_embeddings,
- beta,
- activation,
- dropout_rate,
- )
-
- # Configure decoder.
- channels.reverse()
- kernel_sizes.reverse()
- strides.reverse()
- self.decoder = Decoder(
- channels,
- kernel_sizes,
- strides,
- num_residual_layers,
- embedding_dim,
- upsampling,
- activation,
- dropout_rate,
- )
-
- def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
- """Encodes input to a latent code."""
- return self.encoder(x)
-
- def decode(self, z_q: Tensor) -> Tensor:
- """Reconstructs input from latent codes."""
- return self.decoder(z_q)
-
- def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
- """Compresses and decompresses input."""
- if len(x.shape) < 4:
- x = x[(None,) * (4 - len(x.shape))]
- z_q, vq_loss = self.encode(x)
- x_reconstruction = self.decode(z_q)
- return x_reconstruction, vq_loss