summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/vqvae/vqvae.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/vqvae/vqvae.py')
-rw-r--r--text_recognizer/networks/vqvae/vqvae.py122
1 files changed, 78 insertions, 44 deletions
diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py
index 5aa929b..1585d40 100644
--- a/text_recognizer/networks/vqvae/vqvae.py
+++ b/text_recognizer/networks/vqvae/vqvae.py
@@ -1,10 +1,14 @@
"""The VQ-VAE."""
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Tuple
+import torch
from torch import nn
from torch import Tensor
+import torch.nn.functional as F
-from text_recognizer.networks.vqvae import Decoder, Encoder
+from text_recognizer.networks.vqvae.decoder import Decoder
+from text_recognizer.networks.vqvae.encoder import Encoder
+from text_recognizer.networks.vqvae.quantizer import VectorQuantizer
class VQVAE(nn.Module):
@@ -13,62 +17,92 @@ class VQVAE(nn.Module):
def __init__(
self,
in_channels: int,
- channels: List[int],
- kernel_sizes: List[int],
- strides: List[int],
+ res_channels: int,
num_residual_layers: int,
embedding_dim: int,
num_embeddings: int,
- upsampling: Optional[List[List[int]]] = None,
- beta: float = 0.25,
- activation: str = "leaky_relu",
- dropout_rate: float = 0.0,
- *args: Any,
- **kwargs: Dict,
+ decay: float = 0.99,
+ activation: str = "mish",
) -> None:
super().__init__()
+ # Encoders
+ self.btm_encoder = Encoder(
+ in_channels=1,
+ out_channels=embedding_dim,
+ res_channels=res_channels,
+ num_residual_layers=num_residual_layers,
+ embedding_dim=embedding_dim,
+ activation=activation,
+ )
+
+ self.top_encoder = Encoder(
+ in_channels=embedding_dim,
+ out_channels=embedding_dim,
+ res_channels=res_channels,
+ num_residual_layers=num_residual_layers,
+ embedding_dim=embedding_dim,
+ activation=activation,
+ )
+
+ # Quantizers
+ self.btm_quantizer = VectorQuantizer(
+ num_embeddings=num_embeddings, embedding_dim=embedding_dim, decay=decay,
+ )
- # configure encoder.
- self.encoder = Encoder(
- in_channels,
- channels,
- kernel_sizes,
- strides,
- num_residual_layers,
- embedding_dim,
- num_embeddings,
- beta,
- activation,
- dropout_rate,
+ self.top_quantizer = VectorQuantizer(
+ num_embeddings=num_embeddings, embedding_dim=embedding_dim, decay=decay,
)
- # Configure decoder.
- channels.reverse()
- kernel_sizes.reverse()
- strides.reverse()
- self.decoder = Decoder(
- channels,
- kernel_sizes,
- strides,
- num_residual_layers,
- embedding_dim,
- upsampling,
- activation,
- dropout_rate,
+ # Decoders
+ self.top_decoder = Decoder(
+ in_channels=embedding_dim,
+ out_channels=embedding_dim,
+ embedding_dim=embedding_dim,
+ res_channels=res_channels,
+ num_residual_layers=num_residual_layers,
+ activation=activation,
+ )
+
+ self.btm_decoder = Decoder(
+ in_channels=2 * embedding_dim,
+ out_channels=in_channels,
+ embedding_dim=embedding_dim,
+ res_channels=res_channels,
+ num_residual_layers=num_residual_layers,
+ activation=activation,
)
def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
"""Encodes input to a latent code."""
- return self.encoder(x)
+ z_btm = self.btm_encoder(x)
+ z_top = self.top_encoder(z_btm)
+ return z_btm, z_top
+
+ def quantize(
+ self, z_btm: Tensor, z_top: Tensor
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+ q_btm, vq_btm_loss = self.top_quantizer(z_btm)
+ q_top, vq_top_loss = self.top_quantizer(z_top)
+ return q_btm, vq_btm_loss, q_top, vq_top_loss
- def decode(self, z_q: Tensor) -> Tensor:
+ def decode(self, q_btm: Tensor, q_top: Tensor) -> Tuple[Tensor, Tensor]:
"""Reconstructs input from latent codes."""
- return self.decoder(z_q)
+ d_top = self.top_decoder(q_top)
+ x_hat = self.btm_decoder(torch.cat((d_top, q_btm), dim=1))
+ return d_top, x_hat
+
+ def loss_fn(
+ self, vq_btm_loss: Tensor, vq_top_loss: Tensor, d_top: Tensor, z_btm: Tensor
+ ) -> Tensor:
+ """Calculates the latent loss."""
+ return 0.5 * (vq_top_loss + vq_btm_loss) + F.mse_loss(d_top, z_btm)
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
"""Compresses and decompresses input."""
- if len(x.shape) < 4:
- x = x[(None,) * (4 - len(x.shape))]
- z_q, vq_loss = self.encode(x)
- x_reconstruction = self.decode(z_q)
- return x_reconstruction, vq_loss
+ z_btm, z_top = self.encode(x)
+ q_btm, vq_btm_loss, q_top, vq_top_loss = self.quantize(z_btm=z_btm, z_top=z_top)
+ d_top, x_hat = self.decode(q_btm=q_btm, q_top=q_top)
+ vq_loss = self.loss_fn(
+ vq_btm_loss=vq_btm_loss, vq_top_loss=vq_top_loss, d_top=d_top, z_btm=z_btm
+ )
+ return x_hat, vq_loss