Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--  text_recognizer/networks/vq_transformer.py    77
-rw-r--r--  text_recognizer/networks/vqvae/__init__.py      3
-rw-r--r--  text_recognizer/networks/vqvae/decoder.py     164
-rw-r--r--  text_recognizer/networks/vqvae/encoder.py     176
-rw-r--r--  text_recognizer/networks/vqvae/quantizer.py    51
            (renamed from text_recognizer/networks/vqvae/vector_quantizer.py)
-rw-r--r--  text_recognizer/networks/vqvae/residual.py     18
-rw-r--r--  text_recognizer/networks/vqvae/vqvae.py       122
7 files changed, 310 insertions(+), 301 deletions(-)
diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py
new file mode 100644
index 0000000..a972565
--- /dev/null
+++ b/text_recognizer/networks/vq_transformer.py
@@ -0,0 +1,77 @@
+"""Vector quantized encoder, transformer decoder."""
+from typing import Tuple
+
+from torch import nn, Tensor
+
+from text_recognizer.networks.encoders.efficientnet import EfficientNet
+from text_recognizer.networks.conv_transformer import ConvTransformer
+from text_recognizer.networks.transformer.layers import Decoder
+from text_recognizer.networks.transformer.positional_encodings import (
+    PositionalEncoding2D,
+)
+
+
+class VqTransformer(ConvTransformer):
+ """Convolutional encoder and transformer decoder network."""
+
+ def __init__(
+ self,
+ input_dims: Tuple[int, int, int],
+ hidden_dim: int,
+ dropout_rate: float,
+ num_classes: int,
+ pad_index: Tensor,
+ encoder: EfficientNet,
+ decoder: Decoder,
+ ) -> None:
+ # TODO: Load pretrained vqvae encoder.
+ super().__init__(
+ input_dims=input_dims,
+ hidden_dim=hidden_dim,
+ dropout_rate=dropout_rate,
+ num_classes=num_classes,
+ pad_index=pad_index,
+ encoder=encoder,
+ decoder=decoder,
+ )
+        # Latent projector for downsampling the number of filters and adding
+        # 2D positional encoding.
+ self.latent_encoder = nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.encoder.out_channels,
+ out_channels=self.hidden_dim,
+ kernel_size=1,
+ ),
+ PositionalEncoding2D(
+ hidden_dim=self.hidden_dim,
+ max_h=self.input_dims[1],
+ max_w=self.input_dims[2],
+ ),
+ nn.Flatten(start_dim=2),
+ )
+
+ def encode(self, x: Tensor) -> Tensor:
+ """Encodes an image into a latent feature vector.
+
+ Args:
+ x (Tensor): Image tensor.
+
+        Shape:
+            - x: :math:`(B, C, H, W)`
+            - z: :math:`(B, Sx, E)`
+
+        where Sx is the length of the flattened feature maps projected from
+        the encoder, and E is the latent dimension of each pixel in the
+        projected feature maps.
+
+ Returns:
+            Tensor: A latent embedding of the image.
+ """
+ z = self.encoder(x)
+ z = self.latent_encoder(z)
+
+ # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
+ z = z.permute(0, 2, 1)
+ return z
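
As a sanity check on the new encode path, here is a minimal, self-contained sketch of the shape flow (the stand-in encoder, its 4x downsampling, and all sizes are assumptions; the positional encoding is omitted):

    import torch
    from torch import nn

    # Stand-in for the EfficientNet encoder: 64 filters, 4x downsampling (assumed).
    encoder = nn.Conv2d(1, 64, kernel_size=4, stride=4)
    hidden_dim = 32
    latent_encoder = nn.Sequential(
        nn.Conv2d(64, hidden_dim, kernel_size=1),  # 1x1 projection to hidden_dim
        nn.Flatten(start_dim=2),                   # [B, E, Ho, Wo] -> [B, E, Ho * Wo]
    )

    x = torch.randn(2, 1, 32, 128)                   # [B, C, H, W]
    z = latent_encoder(encoder(x)).permute(0, 2, 1)  # [B, Sx, E]
    assert z.shape == (2, (32 // 4) * (128 // 4), hidden_dim)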
diff --git a/text_recognizer/networks/vqvae/__init__.py b/text_recognizer/networks/vqvae/__init__.py
index 763953c..7d56bdb 100644
--- a/text_recognizer/networks/vqvae/__init__.py
+++ b/text_recognizer/networks/vqvae/__init__.py
@@ -1,5 +1,2 @@
"""VQ-VAE module."""
-from .decoder import Decoder
-from .encoder import Encoder
-from .vector_quantizer import VectorQuantizer
from .vqvae import VQVAE
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
index 32de912..3f59f0d 100644
--- a/text_recognizer/networks/vqvae/decoder.py
+++ b/text_recognizer/networks/vqvae/decoder.py
@@ -1,133 +1,65 @@
"""CNN decoder for the VQ-VAE."""
-
-from typing import List, Optional, Tuple, Type
-
-import torch
+import attr
from torch import nn
from torch import Tensor
from text_recognizer.networks.util import activation_function
-from text_recognizer.networks.vqvae.encoder import _ResidualBlock
+from text_recognizer.networks.vqvae.residual import Residual
+@attr.s(eq=False)
class Decoder(nn.Module):
"""A CNN encoder network."""
- def __init__(
- self,
- channels: List[int],
- kernel_sizes: List[int],
- strides: List[int],
- num_residual_layers: int,
- embedding_dim: int,
- upsampling: Optional[List[List[int]]] = None,
- activation: str = "leaky_relu",
- dropout_rate: float = 0.0,
- ) -> None:
- super().__init__()
-
- if dropout_rate:
- if activation == "selu":
- dropout = nn.AlphaDropout(p=dropout_rate)
- else:
- dropout = nn.Dropout(p=dropout_rate)
- else:
- dropout = None
-
- self.upsampling = upsampling
-
- self.res_block = nn.ModuleList([])
- self.upsampling_block = nn.ModuleList([])
-
- self.embedding_dim = embedding_dim
- activation = activation_function(activation)
-
- # Configure encoder.
- self.decoder = self._build_decoder(
- channels, kernel_sizes, strides, num_residual_layers, activation, dropout,
- )
-
- def _build_decompression_block(
- self,
- in_channels: int,
- channels: int,
- kernel_sizes: List[int],
- strides: List[int],
- activation: Type[nn.Module],
- dropout: Optional[Type[nn.Module]],
- ) -> nn.ModuleList:
- modules = nn.ModuleList([])
- configuration = zip(channels, kernel_sizes, strides)
- for i, (out_channels, kernel_size, stride) in enumerate(configuration):
- modules.append(
- nn.Sequential(
- nn.ConvTranspose2d(
- in_channels,
- out_channels,
- kernel_size,
- stride=stride,
- padding=1,
- ),
- activation,
- )
- )
-
- if self.upsampling and i < len(self.upsampling):
- modules.append(nn.Upsample(size=self.upsampling[i]),)
+ in_channels: int = attr.ib()
+ embedding_dim: int = attr.ib()
+ out_channels: int = attr.ib()
+ res_channels: int = attr.ib()
+ num_residual_layers: int = attr.ib()
+ activation: str = attr.ib()
+ decoder: nn.Sequential = attr.ib(init=False)
- if dropout is not None:
- modules.append(dropout)
-
- in_channels = out_channels
-
- modules.extend(
- nn.Sequential(
- nn.ConvTranspose2d(
- in_channels, 1, kernel_size=kernel_size, stride=stride, padding=1
- ),
- nn.Tanh(),
- )
- )
-
- return modules
-
- def _build_decoder(
- self,
- channels: int,
- kernel_sizes: List[int],
- strides: List[int],
- num_residual_layers: int,
- activation: Type[nn.Module],
- dropout: Optional[Type[nn.Module]],
- ) -> nn.Sequential:
-
- self.res_block.append(
- nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,)
- )
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
- # Bottleneck module.
- self.res_block.extend(
- nn.ModuleList(
- [
- _ResidualBlock(channels[0], channels[0], dropout)
- for i in range(num_residual_layers)
- ]
+ def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
+ self.decoder = self._build_decompression_block()
+
+    def _build_decompression_block(self) -> nn.Sequential:
+ activation_fn = activation_function(self.activation)
+ blocks = [
+ nn.Conv2d(
+ in_channels=self.in_channels,
+ out_channels=self.embedding_dim,
+ kernel_size=3,
+ padding=1,
)
- )
-
- # Decompression module
- self.upsampling_block.extend(
- self._build_decompression_block(
- channels[0], channels[1:], kernel_sizes, strides, activation, dropout
+ ]
+ for _ in range(self.num_residual_layers):
+ blocks.append(
+ Residual(in_channels=self.embedding_dim, out_channels=self.res_channels)
)
- )
-
- self.res_block = nn.Sequential(*self.res_block)
- self.upsampling_block = nn.Sequential(*self.upsampling_block)
-
- return nn.Sequential(self.res_block, self.upsampling_block)
+ blocks.append(activation_fn)
+ blocks += [
+ nn.ConvTranspose2d(
+ in_channels=self.embedding_dim,
+ out_channels=self.embedding_dim // 2,
+ kernel_size=4,
+ stride=2,
+ padding=1,
+ ),
+ activation_fn,
+ nn.ConvTranspose2d(
+ in_channels=self.embedding_dim // 2,
+ out_channels=self.out_channels,
+ kernel_size=4,
+ stride=2,
+ padding=1,
+ ),
+ ]
+ return nn.Sequential(*blocks)
def forward(self, z_q: Tensor) -> Tensor:
"""Reconstruct input from given codes."""
- x_reconstruction = self.decoder(z_q)
- return x_reconstruction
+ return self.decoder(z_q)
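
The rebuilt decompression block has a fixed topology: a 3x3 stem, the residual stack, then two stride-2 transposed convolutions, each of which doubles the spatial size. A minimal sketch under assumed sizes (residual layers omitted, in_channels taken equal to embedding_dim):

    import torch
    from torch import nn

    embedding_dim, out_channels = 64, 1
    decoder = nn.Sequential(
        nn.Conv2d(embedding_dim, embedding_dim, kernel_size=3, padding=1),
        nn.ConvTranspose2d(embedding_dim, embedding_dim // 2, kernel_size=4, stride=2, padding=1),
        nn.Mish(),
        nn.ConvTranspose2d(embedding_dim // 2, out_channels, kernel_size=4, stride=2, padding=1),
    )

    z_q = torch.randn(2, embedding_dim, 8, 32)
    x_hat = decoder(z_q)
    assert x_hat.shape == (2, out_channels, 32, 128)  # 4x upsampling overall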
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
index 65801df..e480545 100644
--- a/text_recognizer/networks/vqvae/encoder.py
+++ b/text_recognizer/networks/vqvae/encoder.py
@@ -1,147 +1,75 @@
"""CNN encoder for the VQ-VAE."""
from typing import Sequence, Optional, Tuple, Type
-import torch
+import attr
from torch import nn
from torch import Tensor
from text_recognizer.networks.util import activation_function
-from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer
-
-
-class _ResidualBlock(nn.Module):
- def __init__(
- self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]],
- ) -> None:
- super().__init__()
- self.block = [
- nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
- nn.ReLU(inplace=True),
- nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False),
- ]
-
- if dropout is not None:
- self.block.append(dropout)
-
- self.block = nn.Sequential(*self.block)
-
- def forward(self, x: Tensor) -> Tensor:
- """Apply the residual forward pass."""
- return x + self.block(x)
+from text_recognizer.networks.vqvae.residual import Residual
+@attr.s(eq=False)
class Encoder(nn.Module):
"""A CNN encoder network."""
- def __init__(
- self,
- in_channels: int,
- channels: Sequence[int],
- kernel_sizes: Sequence[int],
- strides: Sequence[int],
- num_residual_layers: int,
- embedding_dim: int,
- num_embeddings: int,
- beta: float = 0.25,
- activation: str = "leaky_relu",
- dropout_rate: float = 0.0,
- ) -> None:
- super().__init__()
-
- if dropout_rate:
- if activation == "selu":
- dropout = nn.AlphaDropout(p=dropout_rate)
- else:
- dropout = nn.Dropout(p=dropout_rate)
- else:
- dropout = None
-
- self.embedding_dim = embedding_dim
- self.num_embeddings = num_embeddings
- self.beta = beta
- activation = activation_function(activation)
-
- # Configure encoder.
- self.encoder = self._build_encoder(
- in_channels,
- channels,
- kernel_sizes,
- strides,
- num_residual_layers,
- activation,
- dropout,
- )
+ in_channels: int = attr.ib()
+ out_channels: int = attr.ib()
+ res_channels: int = attr.ib()
+ num_residual_layers: int = attr.ib()
+ embedding_dim: int = attr.ib()
+ activation: str = attr.ib()
+ encoder: nn.Sequential = attr.ib(init=False)
- # Configure Vector Quantizer.
- self.vector_quantizer = VectorQuantizer(
- self.num_embeddings, self.embedding_dim, self.beta
- )
-
- @staticmethod
- def _build_compression_block(
- in_channels: int,
- channels: int,
- kernel_sizes: Sequence[int],
- strides: Sequence[int],
- activation: Type[nn.Module],
- dropout: Optional[Type[nn.Module]],
- ) -> nn.ModuleList:
- modules = nn.ModuleList([])
- configuration = zip(channels, kernel_sizes, strides)
- for out_channels, kernel_size, stride in configuration:
- modules.append(
- nn.Sequential(
- nn.Conv2d(
- in_channels, out_channels, kernel_size, stride=stride, padding=1
- ),
- activation,
- )
- )
-
- if dropout is not None:
- modules.append(dropout)
-
- in_channels = out_channels
-
- return modules
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
- def _build_encoder(
- self,
- in_channels: int,
- channels: int,
- kernel_sizes: Sequence[int],
- strides: Sequence[int],
- num_residual_layers: int,
- activation: Type[nn.Module],
- dropout: Optional[Type[nn.Module]],
- ) -> nn.Sequential:
- encoder = nn.ModuleList([])
+ def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
+ self.encoder = self._build_compression_block()
+
+ def _build_compression_block(self) -> nn.Sequential:
+ activation_fn = activation_function(self.activation)
+ block = [
+ nn.Conv2d(
+ in_channels=self.in_channels,
+ out_channels=self.out_channels // 2,
+ kernel_size=4,
+ stride=2,
+ padding=1,
+ ),
+ activation_fn,
+ nn.Conv2d(
+ in_channels=self.out_channels // 2,
+ out_channels=self.out_channels,
+ kernel_size=4,
+ stride=2,
+ padding=1,
+ ),
+ activation_fn,
+ nn.Conv2d(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels,
+ kernel_size=3,
+ padding=1,
+ ),
+ ]
- # compression module
- encoder.extend(
- self._build_compression_block(
- in_channels, channels, kernel_sizes, strides, activation, dropout
+ for _ in range(self.num_residual_layers):
+ block.append(
+ Residual(in_channels=self.out_channels, out_channels=self.res_channels)
)
- )
- # Bottleneck module.
- encoder.extend(
- nn.ModuleList(
- [
- _ResidualBlock(channels[-1], channels[-1], dropout)
- for i in range(num_residual_layers)
- ]
+ block.append(
+ nn.Conv2d(
+ in_channels=self.out_channels,
+ out_channels=self.embedding_dim,
+ kernel_size=1,
)
)
- encoder.append(
- nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,)
- )
-
- return nn.Sequential(*encoder)
+ return nn.Sequential(*block)
-    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
-        """Encodes input into a discrete representation."""
-        z_e = self.encoder(x)
-        z_q, vq_loss = self.vector_quantizer(z_e)
-        return z_q, vq_loss
+    def forward(self, x: Tensor) -> Tensor:
+        """Encodes input into a latent representation."""
+        return self.encoder(x)
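
The compression block mirrors the decoder: two stride-2 convolutions each halve the spatial size, and a final 1x1 convolution projects to the embedding dimension. A minimal sketch under assumed sizes (residual layers omitted):

    import torch
    from torch import nn

    out_channels, embedding_dim = 64, 64
    encoder = nn.Sequential(
        nn.Conv2d(1, out_channels // 2, kernel_size=4, stride=2, padding=1),
        nn.Mish(),
        nn.Conv2d(out_channels // 2, out_channels, kernel_size=4, stride=2, padding=1),
        nn.Mish(),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.Conv2d(out_channels, embedding_dim, kernel_size=1),
    )

    x = torch.randn(2, 1, 32, 128)
    z_e = encoder(x)
    assert z_e.shape == (2, embedding_dim, 8, 32)  # 4x downsampling overall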
diff --git a/text_recognizer/networks/vqvae/vector_quantizer.py b/text_recognizer/networks/vqvae/quantizer.py
index f92c7ee..5e0b602 100644
--- a/text_recognizer/networks/vqvae/vector_quantizer.py
+++ b/text_recognizer/networks/vqvae/quantizer.py
@@ -2,9 +2,7 @@
Reference:
https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py
-
"""
-
from einops import rearrange
import torch
from torch import nn
@@ -12,21 +10,27 @@ from torch import Tensor
from torch.nn import functional as F
+class EmbeddingEMA(nn.Module):
+    """Codebook embedding updated with an exponential moving average."""
+
+    def __init__(self, num_embeddings: int, embedding_dim: int) -> None:
+        super().__init__()
+        weight = torch.zeros(num_embeddings, embedding_dim)
+        nn.init.kaiming_uniform_(weight, nonlinearity="linear")
+        self.register_buffer("weight", weight)
+        self.register_buffer("_cluster_size", torch.zeros(num_embeddings))
+        # Clone so the running average does not alias the codebook weights.
+        self.register_buffer("_weight_avg", weight.clone())
+
+
class VectorQuantizer(nn.Module):
"""The codebook that contains quantized vectors."""
def __init__(
- self, num_embeddings: int, embedding_dim: int, beta: float = 0.25
+ self, num_embeddings: int, embedding_dim: int, decay: float = 0.99
) -> None:
super().__init__()
- self.K = num_embeddings
- self.D = embedding_dim
- self.beta = beta
-
- self.embedding = nn.Embedding(self.K, self.D)
-
- # Initialize the codebook.
- nn.init.uniform_(self.embedding.weight, -1 / self.K, 1 / self.K)
+ self.num_embeddings = num_embeddings
+ self.embedding_dim = embedding_dim
+ self.decay = decay
+ self.embedding = EmbeddingEMA(self.num_embeddings, self.embedding_dim)
def discretization_bottleneck(self, latent: Tensor) -> Tensor:
"""Computes the code nearest to the latent representation.
@@ -62,7 +66,7 @@ class VectorQuantizer(nn.Module):
# Convert to one-hot encodings, aka discrete bottleneck.
one_hot_encoding = torch.zeros(
- encoding_indices.shape[0], self.K, device=latent.device
+ encoding_indices.shape[0], self.num_embeddings, device=latent.device
)
one_hot_encoding.scatter_(1, encoding_indices, 1) # [BHW x K]
@@ -71,9 +75,27 @@ class VectorQuantizer(nn.Module):
quantized_latent = rearrange(
quantized_latent, "(b h w) d -> b h w d", b=b, h=h, w=w
)
+ if self.training:
+ self.compute_ema(one_hot_encoding=one_hot_encoding, latent=latent)
return quantized_latent
+    def compute_ema(self, one_hot_encoding: Tensor, latent: Tensor) -> None:
+        """Updates the codebook with an exponential moving average."""
+        batch_cluster_size = one_hot_encoding.sum(dim=0)
+        batch_embedding_avg = (latent.t() @ one_hot_encoding).t()
+        self.embedding._cluster_size.data.mul_(self.decay).add_(
+            batch_cluster_size, alpha=1 - self.decay
+        )
+        self.embedding._weight_avg.data.mul_(self.decay).add_(
+            batch_embedding_avg, alpha=1 - self.decay
+        )
+        new_embedding = self.embedding._weight_avg / (
+            self.embedding._cluster_size + 1.0e-5
+        ).unsqueeze(1)
+        self.embedding.weight.data.copy_(new_embedding)
+
def vq_loss(self, latent: Tensor, quantized_latent: Tensor) -> Tensor:
"""Vector Quantization loss.
@@ -96,9 +118,10 @@ class VectorQuantizer(nn.Module):
            Tensor: The combined VQ loss.
"""
-        embedding_loss = F.mse_loss(quantized_latent, latent.detach())
        commitment_loss = F.mse_loss(quantized_latent.detach(), latent)
-        return embedding_loss + self.beta * commitment_loss
+        # The embedding (codebook) loss is replaced by the EMA update in
+        # compute_ema, so only the commitment loss remains.
+        return commitment_loss
-    def forward(self, latent: Tensor) -> Tensor:
+    def forward(self, latent: Tensor) -> Tuple[Tensor, Tensor]:
        """Forward pass that returns the quantized vector and the vq loss."""
diff --git a/text_recognizer/networks/vqvae/residual.py b/text_recognizer/networks/vqvae/residual.py
new file mode 100644
index 0000000..98109b8
--- /dev/null
+++ b/text_recognizer/networks/vqvae/residual.py
@@ -0,0 +1,18 @@
+"""Residual block."""
+from torch import nn
+from torch import Tensor
+
+
+class Residual(nn.Module):
+    """Pre-activation residual block."""
+
+    def __init__(self, in_channels: int, out_channels: int) -> None:
+        super().__init__()
+        self.block = nn.Sequential(
+            nn.Mish(inplace=True),
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
+            nn.Mish(inplace=True),
+            nn.Conv2d(out_channels, in_channels, kernel_size=1, bias=False),
+        )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Apply the residual forward pass."""
+ return x + self.block(x)
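
Usage note: out_channels only sets the width of the inner 3x3 convolution; the final 1x1 convolution maps back to in_channels, so the skip connection always matches and the block is shape-preserving:

    import torch
    from text_recognizer.networks.vqvae.residual import Residual

    block = Residual(in_channels=64, out_channels=32)
    x = torch.randn(2, 64, 8, 32)
    assert block(x).shape == x.shape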
diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py
index 5aa929b..1585d40 100644
--- a/text_recognizer/networks/vqvae/vqvae.py
+++ b/text_recognizer/networks/vqvae/vqvae.py
@@ -1,10 +1,14 @@
"""The VQ-VAE."""
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Tuple
+import torch
from torch import nn
from torch import Tensor
+import torch.nn.functional as F
-from text_recognizer.networks.vqvae import Decoder, Encoder
+from text_recognizer.networks.vqvae.decoder import Decoder
+from text_recognizer.networks.vqvae.encoder import Encoder
+from text_recognizer.networks.vqvae.quantizer import VectorQuantizer
class VQVAE(nn.Module):
@@ -13,62 +17,92 @@ class VQVAE(nn.Module):
def __init__(
self,
in_channels: int,
- channels: List[int],
- kernel_sizes: List[int],
- strides: List[int],
+ res_channels: int,
num_residual_layers: int,
embedding_dim: int,
num_embeddings: int,
- upsampling: Optional[List[List[int]]] = None,
- beta: float = 0.25,
- activation: str = "leaky_relu",
- dropout_rate: float = 0.0,
- *args: Any,
- **kwargs: Dict,
+ decay: float = 0.99,
+ activation: str = "mish",
) -> None:
super().__init__()
+ # Encoders
+        self.btm_encoder = Encoder(
+            in_channels=in_channels,
+ out_channels=embedding_dim,
+ res_channels=res_channels,
+ num_residual_layers=num_residual_layers,
+ embedding_dim=embedding_dim,
+ activation=activation,
+ )
+
+ self.top_encoder = Encoder(
+ in_channels=embedding_dim,
+ out_channels=embedding_dim,
+ res_channels=res_channels,
+ num_residual_layers=num_residual_layers,
+ embedding_dim=embedding_dim,
+ activation=activation,
+ )
+
+ # Quantizers
+ self.btm_quantizer = VectorQuantizer(
+ num_embeddings=num_embeddings, embedding_dim=embedding_dim, decay=decay,
+ )
- # configure encoder.
- self.encoder = Encoder(
- in_channels,
- channels,
- kernel_sizes,
- strides,
- num_residual_layers,
- embedding_dim,
- num_embeddings,
- beta,
- activation,
- dropout_rate,
+ self.top_quantizer = VectorQuantizer(
+ num_embeddings=num_embeddings, embedding_dim=embedding_dim, decay=decay,
)
- # Configure decoder.
- channels.reverse()
- kernel_sizes.reverse()
- strides.reverse()
- self.decoder = Decoder(
- channels,
- kernel_sizes,
- strides,
- num_residual_layers,
- embedding_dim,
- upsampling,
- activation,
- dropout_rate,
+ # Decoders
+ self.top_decoder = Decoder(
+ in_channels=embedding_dim,
+ out_channels=embedding_dim,
+ embedding_dim=embedding_dim,
+ res_channels=res_channels,
+ num_residual_layers=num_residual_layers,
+ activation=activation,
+ )
+
+ self.btm_decoder = Decoder(
+ in_channels=2 * embedding_dim,
+ out_channels=in_channels,
+ embedding_dim=embedding_dim,
+ res_channels=res_channels,
+ num_residual_layers=num_residual_layers,
+ activation=activation,
)
def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
"""Encodes input to a latent code."""
- return self.encoder(x)
+ z_btm = self.btm_encoder(x)
+ z_top = self.top_encoder(z_btm)
+ return z_btm, z_top
+
+    def quantize(
+        self, z_btm: Tensor, z_top: Tensor
+    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+        """Quantizes the bottom and top latent codes."""
+        q_btm, vq_btm_loss = self.btm_quantizer(z_btm)
+        q_top, vq_top_loss = self.top_quantizer(z_top)
+        return q_btm, vq_btm_loss, q_top, vq_top_loss
- def decode(self, z_q: Tensor) -> Tensor:
+ def decode(self, q_btm: Tensor, q_top: Tensor) -> Tuple[Tensor, Tensor]:
"""Reconstructs input from latent codes."""
- return self.decoder(z_q)
+ d_top = self.top_decoder(q_top)
+ x_hat = self.btm_decoder(torch.cat((d_top, q_btm), dim=1))
+ return d_top, x_hat
+
+ def loss_fn(
+ self, vq_btm_loss: Tensor, vq_top_loss: Tensor, d_top: Tensor, z_btm: Tensor
+ ) -> Tensor:
+ """Calculates the latent loss."""
+ return 0.5 * (vq_top_loss + vq_btm_loss) + F.mse_loss(d_top, z_btm)
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
"""Compresses and decompresses input."""
- if len(x.shape) < 4:
- x = x[(None,) * (4 - len(x.shape))]
- z_q, vq_loss = self.encode(x)
- x_reconstruction = self.decode(z_q)
- return x_reconstruction, vq_loss
+ z_btm, z_top = self.encode(x)
+ q_btm, vq_btm_loss, q_top, vq_top_loss = self.quantize(z_btm=z_btm, z_top=z_top)
+ d_top, x_hat = self.decode(q_btm=q_btm, q_top=q_top)
+ vq_loss = self.loss_fn(
+ vq_btm_loss=vq_btm_loss, vq_top_loss=vq_top_loss, d_top=d_top, z_btm=z_btm
+ )
+ return x_hat, vq_loss
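
An end-to-end usage sketch of the new two-level model (all sizes are assumptions; it also assumes the quantizer's forward pass, truncated in the diff above, returns channel-first latents). Each level compresses by 4x, so a 32x128 input should round-trip at full resolution:

    import torch
    from text_recognizer.networks.vqvae.vqvae import VQVAE

    model = VQVAE(
        in_channels=1,
        res_channels=32,
        num_residual_layers=2,
        embedding_dim=64,
        num_embeddings=512,
    )
    x = torch.randn(2, 1, 32, 128)
    x_hat, vq_loss = model(x)
    assert x_hat.shape == x.shape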