diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-04 05:03:51 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-04 05:03:51 +0200 |
commit | d3afa310f77f47553586eeee58e3d3345a754e2c (patch) | |
tree | 08b7de1daf2550852d0a1e4d4d75202f14bb03d4 /text_recognizer | |
parent | 65d5f6c694e73792e40ed693a1381a792da8d277 (diff) |
New VQVAE
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/models/vqvae.py | 16 | ||||
-rw-r--r-- | text_recognizer/networks/vq_transformer.py | 77 | ||||
-rw-r--r-- | text_recognizer/networks/vqvae/__init__.py | 3 | ||||
-rw-r--r-- | text_recognizer/networks/vqvae/decoder.py | 164 | ||||
-rw-r--r-- | text_recognizer/networks/vqvae/encoder.py | 176 | ||||
-rw-r--r-- | text_recognizer/networks/vqvae/quantizer.py (renamed from text_recognizer/networks/vqvae/vector_quantizer.py) | 51 | ||||
-rw-r--r-- | text_recognizer/networks/vqvae/residual.py | 18 | ||||
-rw-r--r-- | text_recognizer/networks/vqvae/vqvae.py | 122 |
8 files changed, 319 insertions, 308 deletions
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py index 22da018..5890fd9 100644 --- a/text_recognizer/models/vqvae.py +++ b/text_recognizer/models/vqvae.py @@ -14,31 +14,33 @@ from text_recognizer.models.base import BaseLitModel class VQVAELitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" + latent_loss_weight: float = attr.ib(default=0.25) + def forward(self, data: Tensor) -> Tensor: """Forward pass with the transformer network.""" - return self.network.predict(data) + return self.network(data) def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: """Training step.""" data, _ = batch - reconstructions, vq_loss = self.network(data) + reconstructions, vq_loss = self(data) loss = self.loss_fn(reconstructions, data) - loss += vq_loss + loss += self.latent_loss_weight * vq_loss self.log("train/loss", loss) return loss def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Validation step.""" data, _ = batch - reconstructions, vq_loss = self.network(data) + reconstructions, vq_loss = self(data) loss = self.loss_fn(reconstructions, data) - loss += vq_loss + loss += self.latent_loss_weight * vq_loss self.log("val/loss", loss, prog_bar=True) def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Test step.""" data, _ = batch - reconstructions, vq_loss = self.network(data) + reconstructions, vq_loss = self(data) loss = self.loss_fn(reconstructions, data) - loss += vq_loss + loss += self.latent_loss_weight * vq_loss self.log("test/loss", loss) diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py new file mode 100644 index 0000000..a972565 --- /dev/null +++ b/text_recognizer/networks/vq_transformer.py @@ -0,0 +1,77 @@ +"""Vector quantized encoder, transformer decoder.""" +import math +from typing import Tuple + +from torch import nn, Tensor + +from text_recognizer.networks.encoders.efficientnet import EfficientNet +from text_recognizer.networks.conv_transformer import ConvTransformer +from text_recognizer.networks.transformer.layers import Decoder +from text_recognizer.networks.transformer.positional_encodings import ( + PositionalEncoding, + PositionalEncoding2D, +) + + +class VqTransformer(ConvTransformer): + """Convolutional encoder and transformer decoder network.""" + + def __init__( + self, + input_dims: Tuple[int, int, int], + hidden_dim: int, + dropout_rate: float, + num_classes: int, + pad_index: Tensor, + encoder: EfficientNet, + decoder: Decoder, + ) -> None: + # TODO: Load pretrained vqvae encoder. + super().__init__( + input_dims=input_dims, + hidden_dim=hidden_dim, + dropout_rate=dropout_rate, + num_classes=num_classes, + pad_index=pad_index, + encoder=encoder, + decoder=decoder, + ) + # Latent projector for down sampling number of filters and 2d + # positional encoding. + self.latent_encoder = nn.Sequential( + nn.Conv2d( + in_channels=self.encoder.out_channels, + out_channels=self.hidden_dim, + kernel_size=1, + ), + PositionalEncoding2D( + hidden_dim=self.hidden_dim, + max_h=self.input_dims[1], + max_w=self.input_dims[2], + ), + nn.Flatten(start_dim=2), + ) + + def encode(self, x: Tensor) -> Tensor: + """Encodes an image into a latent feature vector. + + Args: + x (Tensor): Image tensor. + + Shape: + - x: :math: `(B, C, H, W)` + - z: :math: `(B, Sx, E)` + + where Sx is the length of the flattened feature maps projected from + the encoder. E latent dimension for each pixel in the projected + feature maps. + + Returns: + Tensor: A Latent embedding of the image. + """ + z = self.encoder(x) + z = self.latent_encoder(z) + + # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] + z = z.permute(0, 2, 1) + return z diff --git a/text_recognizer/networks/vqvae/__init__.py b/text_recognizer/networks/vqvae/__init__.py index 763953c..7d56bdb 100644 --- a/text_recognizer/networks/vqvae/__init__.py +++ b/text_recognizer/networks/vqvae/__init__.py @@ -1,5 +1,2 @@ """VQ-VAE module.""" -from .decoder import Decoder -from .encoder import Encoder -from .vector_quantizer import VectorQuantizer from .vqvae import VQVAE diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py index 32de912..3f59f0d 100644 --- a/text_recognizer/networks/vqvae/decoder.py +++ b/text_recognizer/networks/vqvae/decoder.py @@ -1,133 +1,65 @@ """CNN decoder for the VQ-VAE.""" - -from typing import List, Optional, Tuple, Type - -import torch +import attr from torch import nn from torch import Tensor from text_recognizer.networks.util import activation_function -from text_recognizer.networks.vqvae.encoder import _ResidualBlock +from text_recognizer.networks.vqvae.residual import Residual +@attr.s(eq=False) class Decoder(nn.Module): """A CNN encoder network.""" - def __init__( - self, - channels: List[int], - kernel_sizes: List[int], - strides: List[int], - num_residual_layers: int, - embedding_dim: int, - upsampling: Optional[List[List[int]]] = None, - activation: str = "leaky_relu", - dropout_rate: float = 0.0, - ) -> None: - super().__init__() - - if dropout_rate: - if activation == "selu": - dropout = nn.AlphaDropout(p=dropout_rate) - else: - dropout = nn.Dropout(p=dropout_rate) - else: - dropout = None - - self.upsampling = upsampling - - self.res_block = nn.ModuleList([]) - self.upsampling_block = nn.ModuleList([]) - - self.embedding_dim = embedding_dim - activation = activation_function(activation) - - # Configure encoder. - self.decoder = self._build_decoder( - channels, kernel_sizes, strides, num_residual_layers, activation, dropout, - ) - - def _build_decompression_block( - self, - in_channels: int, - channels: int, - kernel_sizes: List[int], - strides: List[int], - activation: Type[nn.Module], - dropout: Optional[Type[nn.Module]], - ) -> nn.ModuleList: - modules = nn.ModuleList([]) - configuration = zip(channels, kernel_sizes, strides) - for i, (out_channels, kernel_size, stride) in enumerate(configuration): - modules.append( - nn.Sequential( - nn.ConvTranspose2d( - in_channels, - out_channels, - kernel_size, - stride=stride, - padding=1, - ), - activation, - ) - ) - - if self.upsampling and i < len(self.upsampling): - modules.append(nn.Upsample(size=self.upsampling[i]),) + in_channels: int = attr.ib() + embedding_dim: int = attr.ib() + out_channels: int = attr.ib() + res_channels: int = attr.ib() + num_residual_layers: int = attr.ib() + activation: str = attr.ib() + decoder: nn.Sequential = attr.ib(init=False) - if dropout is not None: - modules.append(dropout) - - in_channels = out_channels - - modules.extend( - nn.Sequential( - nn.ConvTranspose2d( - in_channels, 1, kernel_size=kernel_size, stride=stride, padding=1 - ), - nn.Tanh(), - ) - ) - - return modules - - def _build_decoder( - self, - channels: int, - kernel_sizes: List[int], - strides: List[int], - num_residual_layers: int, - activation: Type[nn.Module], - dropout: Optional[Type[nn.Module]], - ) -> nn.Sequential: - - self.res_block.append( - nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,) - ) + def __attrs_pre_init__(self) -> None: + super().__init__() - # Bottleneck module. - self.res_block.extend( - nn.ModuleList( - [ - _ResidualBlock(channels[0], channels[0], dropout) - for i in range(num_residual_layers) - ] + def __attrs_post_init__(self) -> None: + """Post init configuration.""" + self.decoder = self._build_decompression_block() + + def _build_decompression_block(self,) -> nn.Sequential: + activation_fn = activation_function(self.activation) + blocks = [ + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.embedding_dim, + kernel_size=3, + padding=1, ) - ) - - # Decompression module - self.upsampling_block.extend( - self._build_decompression_block( - channels[0], channels[1:], kernel_sizes, strides, activation, dropout + ] + for _ in range(self.num_residual_layers): + blocks.append( + Residual(in_channels=self.embedding_dim, out_channels=self.res_channels) ) - ) - - self.res_block = nn.Sequential(*self.res_block) - self.upsampling_block = nn.Sequential(*self.upsampling_block) - - return nn.Sequential(self.res_block, self.upsampling_block) + blocks.append(activation_fn) + blocks += [ + nn.ConvTranspose2d( + in_channels=self.embedding_dim, + out_channels=self.embedding_dim // 2, + kernel_size=4, + stride=2, + padding=1, + ), + activation_fn, + nn.ConvTranspose2d( + in_channels=self.embedding_dim // 2, + out_channels=self.out_channels, + kernel_size=4, + stride=2, + padding=1, + ), + ] + return nn.Sequential(*blocks) def forward(self, z_q: Tensor) -> Tensor: """Reconstruct input from given codes.""" - x_reconstruction = self.decoder(z_q) - return x_reconstruction + return self.decoder(z_q) diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py index 65801df..e480545 100644 --- a/text_recognizer/networks/vqvae/encoder.py +++ b/text_recognizer/networks/vqvae/encoder.py @@ -1,147 +1,75 @@ """CNN encoder for the VQ-VAE.""" from typing import Sequence, Optional, Tuple, Type -import torch +import attr from torch import nn from torch import Tensor from text_recognizer.networks.util import activation_function -from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer - - -class _ResidualBlock(nn.Module): - def __init__( - self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]], - ) -> None: - super().__init__() - self.block = [ - nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), - nn.ReLU(inplace=True), - nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False), - ] - - if dropout is not None: - self.block.append(dropout) - - self.block = nn.Sequential(*self.block) - - def forward(self, x: Tensor) -> Tensor: - """Apply the residual forward pass.""" - return x + self.block(x) +from text_recognizer.networks.vqvae.residual import Residual +@attr.s(eq=False) class Encoder(nn.Module): """A CNN encoder network.""" - def __init__( - self, - in_channels: int, - channels: Sequence[int], - kernel_sizes: Sequence[int], - strides: Sequence[int], - num_residual_layers: int, - embedding_dim: int, - num_embeddings: int, - beta: float = 0.25, - activation: str = "leaky_relu", - dropout_rate: float = 0.0, - ) -> None: - super().__init__() - - if dropout_rate: - if activation == "selu": - dropout = nn.AlphaDropout(p=dropout_rate) - else: - dropout = nn.Dropout(p=dropout_rate) - else: - dropout = None - - self.embedding_dim = embedding_dim - self.num_embeddings = num_embeddings - self.beta = beta - activation = activation_function(activation) - - # Configure encoder. - self.encoder = self._build_encoder( - in_channels, - channels, - kernel_sizes, - strides, - num_residual_layers, - activation, - dropout, - ) + in_channels: int = attr.ib() + out_channels: int = attr.ib() + res_channels: int = attr.ib() + num_residual_layers: int = attr.ib() + embedding_dim: int = attr.ib() + activation: str = attr.ib() + encoder: nn.Sequential = attr.ib(init=False) - # Configure Vector Quantizer. - self.vector_quantizer = VectorQuantizer( - self.num_embeddings, self.embedding_dim, self.beta - ) - - @staticmethod - def _build_compression_block( - in_channels: int, - channels: int, - kernel_sizes: Sequence[int], - strides: Sequence[int], - activation: Type[nn.Module], - dropout: Optional[Type[nn.Module]], - ) -> nn.ModuleList: - modules = nn.ModuleList([]) - configuration = zip(channels, kernel_sizes, strides) - for out_channels, kernel_size, stride in configuration: - modules.append( - nn.Sequential( - nn.Conv2d( - in_channels, out_channels, kernel_size, stride=stride, padding=1 - ), - activation, - ) - ) - - if dropout is not None: - modules.append(dropout) - - in_channels = out_channels - - return modules + def __attrs_pre_init__(self) -> None: + super().__init__() - def _build_encoder( - self, - in_channels: int, - channels: int, - kernel_sizes: Sequence[int], - strides: Sequence[int], - num_residual_layers: int, - activation: Type[nn.Module], - dropout: Optional[Type[nn.Module]], - ) -> nn.Sequential: - encoder = nn.ModuleList([]) + def __attrs_post_init__(self) -> None: + """Post init configuration.""" + self.encoder = self._build_compression_block() + + def _build_compression_block(self) -> nn.Sequential: + activation_fn = activation_function(self.activation) + block = [ + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.out_channels // 2, + kernel_size=4, + stride=2, + padding=1, + ), + activation_fn, + nn.Conv2d( + in_channels=self.out_channels // 2, + out_channels=self.out_channels, + kernel_size=4, + stride=2, + padding=1, + ), + activation_fn, + nn.Conv2d( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + padding=1, + ), + ] - # compression module - encoder.extend( - self._build_compression_block( - in_channels, channels, kernel_sizes, strides, activation, dropout + for _ in range(self.num_residual_layers): + block.append( + Residual(in_channels=self.out_channels, out_channels=self.res_channels) ) - ) - # Bottleneck module. - encoder.extend( - nn.ModuleList( - [ - _ResidualBlock(channels[-1], channels[-1], dropout) - for i in range(num_residual_layers) - ] + block.append( + nn.Conv2d( + in_channels=self.out_channels, + out_channels=self.embedding_dim, + kernel_size=1, ) ) - encoder.append( - nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,) - ) - - return nn.Sequential(*encoder) + return nn.Sequential(*block) def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: """Encodes input into a discrete representation.""" - z_e = self.encoder(x) - z_q, vq_loss = self.vector_quantizer(z_e) - return z_q, vq_loss + return self.encoder(x) diff --git a/text_recognizer/networks/vqvae/vector_quantizer.py b/text_recognizer/networks/vqvae/quantizer.py index f92c7ee..5e0b602 100644 --- a/text_recognizer/networks/vqvae/vector_quantizer.py +++ b/text_recognizer/networks/vqvae/quantizer.py @@ -2,9 +2,7 @@ Reference: https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py - """ - from einops import rearrange import torch from torch import nn @@ -12,21 +10,27 @@ from torch import Tensor from torch.nn import functional as F +class EmbeddingEMA(nn.Module): + def __init__(self, num_embeddings: int, embedding_dim: int) -> None: + super().__init__() + weight = torch.zeros(num_embeddings, embedding_dim) + nn.init.kaiming_uniform_(weight, nonlinearity="linear") + self.register_buffer("weight", weight) + self.register_buffer("_cluster_size", torch.zeros(num_embeddings)) + self.register_buffer("_weight_avg", weight) + + class VectorQuantizer(nn.Module): """The codebook that contains quantized vectors.""" def __init__( - self, num_embeddings: int, embedding_dim: int, beta: float = 0.25 + self, num_embeddings: int, embedding_dim: int, decay: float = 0.99 ) -> None: super().__init__() - self.K = num_embeddings - self.D = embedding_dim - self.beta = beta - - self.embedding = nn.Embedding(self.K, self.D) - - # Initialize the codebook. - nn.init.uniform_(self.embedding.weight, -1 / self.K, 1 / self.K) + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.decay = decay + self.embedding = EmbeddingEMA(self.num_embeddings, self.embedding_dim) def discretization_bottleneck(self, latent: Tensor) -> Tensor: """Computes the code nearest to the latent representation. @@ -62,7 +66,7 @@ class VectorQuantizer(nn.Module): # Convert to one-hot encodings, aka discrete bottleneck. one_hot_encoding = torch.zeros( - encoding_indices.shape[0], self.K, device=latent.device + encoding_indices.shape[0], self.num_embeddings, device=latent.device ) one_hot_encoding.scatter_(1, encoding_indices, 1) # [BHW x K] @@ -71,9 +75,27 @@ class VectorQuantizer(nn.Module): quantized_latent = rearrange( quantized_latent, "(b h w) d -> b h w d", b=b, h=h, w=w ) + if self.training: + self.compute_ema(one_hot_encoding=one_hot_encoding, latent=latent) return quantized_latent + def compute_ema(self, one_hot_encoding: Tensor, latent: Tensor) -> None: + batch_cluster_size = one_hot_encoding.sum(axis=0) + batch_embedding_avg = (latent.t() @ one_hot_encoding).t() + print(batch_cluster_size.shape) + print(self.embedding._cluster_size.shape) + self.embedding._cluster_size.data.mul_(self.decay).add_( + batch_cluster_size, alpha=1 - self.decay + ) + self.embedding._weight_avg.data.mul_(self.decay).add_( + batch_embedding_avg, alpha=1 - self.decay + ) + new_embedding = self.embedding._weight_avg / ( + self.embedding._cluster_size + 1.0e-5 + ).unsqueeze(1) + self.embedding.weight.data.copy_(new_embedding) + def vq_loss(self, latent: Tensor, quantized_latent: Tensor) -> Tensor: """Vector Quantization loss. @@ -96,9 +118,10 @@ class VectorQuantizer(nn.Module): Tensor: The combinded VQ loss. """ - embedding_loss = F.mse_loss(quantized_latent, latent.detach()) commitment_loss = F.mse_loss(quantized_latent.detach(), latent) - return embedding_loss + self.beta * commitment_loss + # embedding_loss = F.mse_loss(quantized_latent, latent.detach()) + # return embedding_loss + self.beta * commitment_loss + return commitment_loss def forward(self, latent: Tensor) -> Tensor: """Forward pass that returns the quantized vector and the vq loss.""" diff --git a/text_recognizer/networks/vqvae/residual.py b/text_recognizer/networks/vqvae/residual.py new file mode 100644 index 0000000..98109b8 --- /dev/null +++ b/text_recognizer/networks/vqvae/residual.py @@ -0,0 +1,18 @@ +"""Residual block.""" +from torch import nn +from torch import Tensor + + +class Residual(nn.Module): + def __init__(self, in_channels: int, out_channels: int,) -> None: + super().__init__() + self.block = nn.Sequential( + nn.Mish(inplace=True), + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.Mish(inplace=True), + nn.Conv2d(out_channels, in_channels, kernel_size=1, bias=False), + ) + + def forward(self, x: Tensor) -> Tensor: + """Apply the residual forward pass.""" + return x + self.block(x) diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py index 5aa929b..1585d40 100644 --- a/text_recognizer/networks/vqvae/vqvae.py +++ b/text_recognizer/networks/vqvae/vqvae.py @@ -1,10 +1,14 @@ """The VQ-VAE.""" -from typing import Any, Dict, List, Optional, Tuple +from typing import Tuple +import torch from torch import nn from torch import Tensor +import torch.nn.functional as F -from text_recognizer.networks.vqvae import Decoder, Encoder +from text_recognizer.networks.vqvae.decoder import Decoder +from text_recognizer.networks.vqvae.encoder import Encoder +from text_recognizer.networks.vqvae.quantizer import VectorQuantizer class VQVAE(nn.Module): @@ -13,62 +17,92 @@ class VQVAE(nn.Module): def __init__( self, in_channels: int, - channels: List[int], - kernel_sizes: List[int], - strides: List[int], + res_channels: int, num_residual_layers: int, embedding_dim: int, num_embeddings: int, - upsampling: Optional[List[List[int]]] = None, - beta: float = 0.25, - activation: str = "leaky_relu", - dropout_rate: float = 0.0, - *args: Any, - **kwargs: Dict, + decay: float = 0.99, + activation: str = "mish", ) -> None: super().__init__() + # Encoders + self.btm_encoder = Encoder( + in_channels=1, + out_channels=embedding_dim, + res_channels=res_channels, + num_residual_layers=num_residual_layers, + embedding_dim=embedding_dim, + activation=activation, + ) + + self.top_encoder = Encoder( + in_channels=embedding_dim, + out_channels=embedding_dim, + res_channels=res_channels, + num_residual_layers=num_residual_layers, + embedding_dim=embedding_dim, + activation=activation, + ) + + # Quantizers + self.btm_quantizer = VectorQuantizer( + num_embeddings=num_embeddings, embedding_dim=embedding_dim, decay=decay, + ) - # configure encoder. - self.encoder = Encoder( - in_channels, - channels, - kernel_sizes, - strides, - num_residual_layers, - embedding_dim, - num_embeddings, - beta, - activation, - dropout_rate, + self.top_quantizer = VectorQuantizer( + num_embeddings=num_embeddings, embedding_dim=embedding_dim, decay=decay, ) - # Configure decoder. - channels.reverse() - kernel_sizes.reverse() - strides.reverse() - self.decoder = Decoder( - channels, - kernel_sizes, - strides, - num_residual_layers, - embedding_dim, - upsampling, - activation, - dropout_rate, + # Decoders + self.top_decoder = Decoder( + in_channels=embedding_dim, + out_channels=embedding_dim, + embedding_dim=embedding_dim, + res_channels=res_channels, + num_residual_layers=num_residual_layers, + activation=activation, + ) + + self.btm_decoder = Decoder( + in_channels=2 * embedding_dim, + out_channels=in_channels, + embedding_dim=embedding_dim, + res_channels=res_channels, + num_residual_layers=num_residual_layers, + activation=activation, ) def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]: """Encodes input to a latent code.""" - return self.encoder(x) + z_btm = self.btm_encoder(x) + z_top = self.top_encoder(z_btm) + return z_btm, z_top + + def quantize( + self, z_btm: Tensor, z_top: Tensor + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + q_btm, vq_btm_loss = self.top_quantizer(z_btm) + q_top, vq_top_loss = self.top_quantizer(z_top) + return q_btm, vq_btm_loss, q_top, vq_top_loss - def decode(self, z_q: Tensor) -> Tensor: + def decode(self, q_btm: Tensor, q_top: Tensor) -> Tuple[Tensor, Tensor]: """Reconstructs input from latent codes.""" - return self.decoder(z_q) + d_top = self.top_decoder(q_top) + x_hat = self.btm_decoder(torch.cat((d_top, q_btm), dim=1)) + return d_top, x_hat + + def loss_fn( + self, vq_btm_loss: Tensor, vq_top_loss: Tensor, d_top: Tensor, z_btm: Tensor + ) -> Tensor: + """Calculates the latent loss.""" + return 0.5 * (vq_top_loss + vq_btm_loss) + F.mse_loss(d_top, z_btm) def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: """Compresses and decompresses input.""" - if len(x.shape) < 4: - x = x[(None,) * (4 - len(x.shape))] - z_q, vq_loss = self.encode(x) - x_reconstruction = self.decode(z_q) - return x_reconstruction, vq_loss + z_btm, z_top = self.encode(x) + q_btm, vq_btm_loss, q_top, vq_top_loss = self.quantize(z_btm=z_btm, z_top=z_top) + d_top, x_hat = self.decode(q_btm=q_btm, q_top=q_top) + vq_loss = self.loss_fn( + vq_btm_loss=vq_btm_loss, vq_top_loss=vq_top_loss, d_top=d_top, z_btm=z_btm + ) + return x_hat, vq_loss |