Diffstat (limited to 'text_recognizer/networks/vqvae')
-rw-r--r--  text_recognizer/networks/vqvae/__init__.py  |   1
-rw-r--r--  text_recognizer/networks/vqvae/attention.py |   7
-rw-r--r--  text_recognizer/networks/vqvae/decoder.py   |  83
-rw-r--r--  text_recognizer/networks/vqvae/encoder.py   |  82
-rw-r--r--  text_recognizer/networks/vqvae/norm.py      |   4
-rw-r--r--  text_recognizer/networks/vqvae/pixelcnn.py  | 165
-rw-r--r--  text_recognizer/networks/vqvae/quantizer.py |  15
-rw-r--r--  text_recognizer/networks/vqvae/residual.py  |  53
-rw-r--r--  text_recognizer/networks/vqvae/resize.py    |   2
-rw-r--r--  text_recognizer/networks/vqvae/vqvae.py     |  98
10 files changed, 332 insertions, 178 deletions
diff --git a/text_recognizer/networks/vqvae/__init__.py b/text_recognizer/networks/vqvae/__init__.py
index 7d56bdb..e1f05fa 100644
--- a/text_recognizer/networks/vqvae/__init__.py
+++ b/text_recognizer/networks/vqvae/__init__.py
@@ -1,2 +1 @@
"""VQ-VAE module."""
-from .vqvae import VQVAE
diff --git a/text_recognizer/networks/vqvae/attention.py b/text_recognizer/networks/vqvae/attention.py
index 5a6b3ce..78a2cc9 100644
--- a/text_recognizer/networks/vqvae/attention.py
+++ b/text_recognizer/networks/vqvae/attention.py
@@ -7,7 +7,7 @@ import torch.nn.functional as F
from text_recognizer.networks.vqvae.norm import Normalize
-@attr.s
+@attr.s(eq=False)
class Attention(nn.Module):
"""Convolutional attention."""
@@ -63,11 +63,12 @@ class Attention(nn.Module):
B, C, H, W = q.shape
q = q.reshape(B, C, H * W).permute(0, 2, 1) # [B, HW, C]
k = k.reshape(B, C, H * W) # [B, C, HW]
- energy = torch.bmm(q, k) * (C ** -0.5)
+ energy = torch.bmm(q, k) * (int(C) ** -0.5)
attention = F.softmax(energy, dim=2)
# Attend to the values
- v = v.reshape(B, C, H * W).permute(0, 2, 1) # [B, HW, C]
+ v = v.reshape(B, C, H * W)
+ attention = attention.permute(0, 2, 1) # [B, HW, HW]
out = torch.bmm(v, attention)
out = out.reshape(B, C, H, W)
out = self.proj(out)
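
The deleted line permuted v to [B, HW, C] before the bmm with the [B, HW, HW] attention map, which only shape-checks when C happens to equal HW; the new lines keep v as [B, C, HW] and contract over key positions through the transposed attention map instead. A minimal standalone sketch (shapes illustrative) checking the new formulation against an einsum reference:

    import torch
    import torch.nn.functional as F

    B, C, H, W = 2, 8, 4, 4
    q, k, v = (torch.randn(B, C, H, W) for _ in range(3))

    q_ = q.reshape(B, C, H * W).permute(0, 2, 1)  # [B, HW, C]
    k_ = k.reshape(B, C, H * W)                   # [B, C, HW]
    attention = F.softmax(torch.bmm(q_, k_) * C ** -0.5, dim=2)

    # Reference: out[b, c, i] = sum_j attention[b, i, j] * v[b, c, j].
    ref = torch.einsum("bij,bcj->bci", attention, v.reshape(B, C, H * W))

    # New formulation from the diff: bmm with the transposed attention map.
    out = torch.bmm(v.reshape(B, C, H * W), attention.permute(0, 2, 1))

    assert torch.allclose(ref, out, atol=1e-6)
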
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
index fcf768b..f51e0a3 100644
--- a/text_recognizer/networks/vqvae/decoder.py
+++ b/text_recognizer/networks/vqvae/decoder.py
@@ -1,62 +1,69 @@
"""CNN decoder for the VQ-VAE."""
-import attr
+from typing import Sequence
+
from torch import nn
from torch import Tensor
from text_recognizer.networks.util import activation_function
+from text_recognizer.networks.vqvae.norm import Normalize
from text_recognizer.networks.vqvae.residual import Residual
-@attr.s(eq=False)
class Decoder(nn.Module):
"""A CNN encoder network."""
- in_channels: int = attr.ib()
- embedding_dim: int = attr.ib()
- out_channels: int = attr.ib()
- res_channels: int = attr.ib()
- num_residual_layers: int = attr.ib()
- activation: str = attr.ib()
- decoder: nn.Sequential = attr.ib(init=False)
-
- def __attrs_post_init__(self) -> None:
- """Post init configuration."""
+    def __init__(
+        self,
+        out_channels: int,
+        hidden_dim: int,
+        channels_multipliers: Sequence[int],
+        dropout_rate: float,
+        activation: str = "mish",
+    ) -> None:
super().__init__()
+ self.out_channels = out_channels
+ self.hidden_dim = hidden_dim
+ self.channels_multipliers = tuple(channels_multipliers)
+ self.activation = activation
+ self.dropout_rate = dropout_rate
self.decoder = self._build_decompression_block()
def _build_decompression_block(self,) -> nn.Sequential:
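+        """Builds decoder network."""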
+ in_channels = self.hidden_dim * self.channels_multipliers[0]
+ decoder = []
+ for _ in range(2):
+ decoder += [
+ Residual(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ dropout_rate=self.dropout_rate,
+ use_norm=True,
+ ),
+ ]
+
activation_fn = activation_function(self.activation)
- blocks = [
+        out_channels_multipliers = self.channels_multipliers + (1,)
+ num_blocks = len(self.channels_multipliers)
+
+ for i in range(num_blocks):
+ in_channels = self.hidden_dim * self.channels_multipliers[i]
+ out_channels = self.hidden_dim * out_channels_multipliers[i + 1]
+ decoder += [
+ nn.ConvTranspose2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=4,
+ stride=2,
+ padding=1,
+ ),
+ activation_fn,
+ ]
+
+ decoder += [
+ Normalize(num_channels=self.hidden_dim * out_channels_multipliers[-1]),
+ nn.Mish(),
nn.Conv2d(
- in_channels=self.in_channels,
- out_channels=self.embedding_dim,
- kernel_size=3,
- padding=1,
- )
- ]
- for _ in range(self.num_residual_layers):
- blocks.append(
- Residual(in_channels=self.embedding_dim, out_channels=self.res_channels)
- )
- blocks.append(activation_fn)
- blocks += [
- nn.ConvTranspose2d(
- in_channels=self.embedding_dim,
- out_channels=self.embedding_dim // 2,
- kernel_size=4,
- stride=2,
- padding=1,
- ),
- activation_fn,
- nn.ConvTranspose2d(
- in_channels=self.embedding_dim // 2,
+ in_channels=self.hidden_dim * out_channels_multipliers[-1],
out_channels=self.out_channels,
- kernel_size=4,
- stride=2,
+ kernel_size=3,
+ stride=1,
padding=1,
),
]
- return nn.Sequential(*blocks)
+ return nn.Sequential(*decoder)
def forward(self, z_q: Tensor) -> Tensor:
"""Reconstruct input from given codes."""
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
index f086c6b..ad8f950 100644
--- a/text_recognizer/networks/vqvae/encoder.py
+++ b/text_recognizer/networks/vqvae/encoder.py
@@ -1,7 +1,6 @@
"""CNN encoder for the VQ-VAE."""
-from typing import Sequence, Optional, Tuple, Type
+from typing import List, Tuple
-import attr
from torch import nn
from torch import Tensor
@@ -9,64 +8,59 @@ from text_recognizer.networks.util import activation_function
from text_recognizer.networks.vqvae.residual import Residual
-@attr.s(eq=False)
class Encoder(nn.Module):
"""A CNN encoder network."""
- in_channels: int = attr.ib()
- out_channels: int = attr.ib()
- res_channels: int = attr.ib()
- num_residual_layers: int = attr.ib()
- embedding_dim: int = attr.ib()
- activation: str = attr.ib()
- encoder: nn.Sequential = attr.ib(init=False)
-
- def __attrs_post_init__(self) -> None:
- """Post init configuration."""
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_dim: int,
+        channels_multipliers: List[int],
+        dropout_rate: float,
+        activation: str = "mish",
+    ) -> None:
super().__init__()
+ self.in_channels = in_channels
+ self.hidden_dim = hidden_dim
+ self.channels_multipliers = tuple(channels_multipliers)
+ self.activation = activation
+ self.dropout_rate = dropout_rate
self.encoder = self._build_compression_block()
def _build_compression_block(self) -> nn.Sequential:
- activation_fn = activation_function(self.activation)
- block = [
+ """Builds encoder network."""
+ encoder = [
nn.Conv2d(
in_channels=self.in_channels,
- out_channels=self.out_channels // 2,
- kernel_size=4,
- stride=2,
- padding=1,
- ),
- activation_fn,
- nn.Conv2d(
- in_channels=self.out_channels // 2,
- out_channels=self.out_channels,
- kernel_size=4,
- stride=2,
- padding=1,
- ),
- activation_fn,
- nn.Conv2d(
- in_channels=self.out_channels,
- out_channels=self.out_channels,
+ out_channels=self.hidden_dim,
kernel_size=3,
+ stride=1,
padding=1,
),
]
- for _ in range(self.num_residual_layers):
- block.append(
- Residual(in_channels=self.out_channels, out_channels=self.res_channels)
- )
+ num_blocks = len(self.channels_multipliers)
+        channels_multipliers = (1,) + self.channels_multipliers
+ activation_fn = activation_function(self.activation)
- block.append(
- nn.Conv2d(
- in_channels=self.out_channels,
- out_channels=self.embedding_dim,
- kernel_size=1,
- )
- )
+ for i in range(num_blocks):
+ in_channels = self.hidden_dim * channels_multipliers[i]
+ out_channels = self.hidden_dim * channels_multipliers[i + 1]
+ encoder += [
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=4,
+ stride=2,
+ padding=1,
+ ),
+ activation_fn,
+ ]
+
+ for _ in range(2):
+ encoder += [
+ Residual(
+ in_channels=self.hidden_dim * self.channels_multipliers[-1],
+ out_channels=self.hidden_dim * self.channels_multipliers[-1],
+ dropout_rate=self.dropout_rate,
+ use_norm=True,
+ )
+ ]
- return nn.Sequential(*block)
+ return nn.Sequential(*encoder)
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
"""Encodes input into a discrete representation."""
diff --git a/text_recognizer/networks/vqvae/norm.py b/text_recognizer/networks/vqvae/norm.py
index df66efc..3e6963a 100644
--- a/text_recognizer/networks/vqvae/norm.py
+++ b/text_recognizer/networks/vqvae/norm.py
@@ -3,7 +3,7 @@ import attr
from torch import nn, Tensor
-@attr.s
+@attr.s(eq=False)
class Normalize(nn.Module):
num_channels: int = attr.ib()
norm: nn.GroupNorm = attr.ib(init=False)
@@ -12,7 +12,7 @@ class Normalize(nn.Module):
"""Post init configuration."""
super().__init__()
self.norm = nn.GroupNorm(
- num_groups=32, num_channels=self.num_channels, eps=1.0e-6, affine=True
+ num_groups=self.num_channels, num_channels=self.num_channels, eps=1.0e-6, affine=True
)
def forward(self, x: Tensor) -> Tensor:
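
Setting num_groups equal to num_channels makes the GroupNorm normalize each channel independently per sample, i.e. it behaves like an affine InstanceNorm2d (at default initialization the two agree numerically), and it drops the old requirement that num_channels be divisible by 32:

    import torch
    from torch import nn

    C = 16
    x = torch.randn(2, C, 8, 8)
    gn = nn.GroupNorm(num_groups=C, num_channels=C, eps=1.0e-6, affine=True)
    inorm = nn.InstanceNorm2d(C, eps=1.0e-6, affine=True)
    assert torch.allclose(gn(x), inorm(x), atol=1e-5)
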
diff --git a/text_recognizer/networks/vqvae/pixelcnn.py b/text_recognizer/networks/vqvae/pixelcnn.py
new file mode 100644
index 0000000..5c580df
--- /dev/null
+++ b/text_recognizer/networks/vqvae/pixelcnn.py
@@ -0,0 +1,165 @@
+"""PixelCNN encoder and decoder.
+
+Same as in VQGAN, among others. Hopefully, this yields better reconstructions...
+
+TODO: Add a parameter for the number of residual layers.
+"""
+from typing import Sequence
+
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.vqvae.attention import Attention
+from text_recognizer.networks.vqvae.norm import Normalize
+from text_recognizer.networks.vqvae.residual import Residual
+from text_recognizer.networks.vqvae.resize import Downsample, Upsample
+
+
+class Encoder(nn.Module):
+ """PixelCNN encoder."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ hidden_dim: int,
+ channels_multipliers: Sequence[int],
+ dropout_rate: float,
+ ) -> None:
+ super().__init__()
+ self.in_channels = in_channels
+ self.dropout_rate = dropout_rate
+ self.hidden_dim = hidden_dim
+ self.channels_multipliers = tuple(channels_multipliers)
+ self.encoder = self._build_encoder()
+
+ def _build_encoder(self) -> nn.Sequential:
+ """Builds encoder."""
+ encoder = [
+ nn.Conv2d(
+ in_channels=self.in_channels,
+ out_channels=self.hidden_dim,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ ),
+ ]
+ num_blocks = len(self.channels_multipliers)
+ in_channels_multipliers = (1,) + self.channels_multipliers
+ for i in range(num_blocks):
+ in_channels = self.hidden_dim * in_channels_multipliers[i]
+ out_channels = self.hidden_dim * self.channels_multipliers[i]
+ encoder += [
+ Residual(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ dropout_rate=self.dropout_rate,
+ use_norm=True,
+ ),
+ ]
+ if i == num_blocks - 1:
+ encoder.append(Attention(in_channels=out_channels))
+ encoder.append(Downsample())
+
+ for _ in range(2):
+ encoder += [
+ Residual(
+ in_channels=self.hidden_dim * self.channels_multipliers[-1],
+ out_channels=self.hidden_dim * self.channels_multipliers[-1],
+ dropout_rate=self.dropout_rate,
+ use_norm=True,
+ ),
+                Attention(in_channels=self.hidden_dim * self.channels_multipliers[-1]),
+ ]
+
+ encoder += [
+ Normalize(num_channels=self.hidden_dim * self.channels_multipliers[-1]),
+ nn.Mish(),
+ nn.Conv2d(
+ in_channels=self.hidden_dim * self.channels_multipliers[-1],
+ out_channels=self.hidden_dim * self.channels_multipliers[-1],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ ),
+ ]
+ return nn.Sequential(*encoder)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Encodes input to a latent representation."""
+ return self.encoder(x)
+
+
+class Decoder(nn.Module):
+ """PixelCNN decoder."""
+
+ def __init__(
+ self,
+ hidden_dim: int,
+ channels_multipliers: Sequence[int],
+ out_channels: int,
+ dropout_rate: float,
+ ) -> None:
+ super().__init__()
+ self.hidden_dim = hidden_dim
+ self.out_channels = out_channels
+ self.channels_multipliers = tuple(channels_multipliers)
+ self.dropout_rate = dropout_rate
+ self.decoder = self._build_decoder()
+
+ def _build_decoder(self) -> nn.Sequential:
+ """Builds decoder."""
+ in_channels = self.hidden_dim * self.channels_multipliers[0]
+ decoder = [
+ Residual(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ dropout_rate=self.dropout_rate,
+ use_norm=True,
+ ),
+ Attention(in_channels=in_channels),
+ Residual(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ dropout_rate=self.dropout_rate,
+ use_norm=True,
+ ),
+ ]
+
+        out_channels_multipliers = self.channels_multipliers + (1,)
+ num_blocks = len(self.channels_multipliers)
+
+ for i in range(num_blocks):
+ in_channels = self.hidden_dim * self.channels_multipliers[i]
+ out_channels = self.hidden_dim * out_channels_multipliers[i + 1]
+ decoder.append(
+ Residual(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ dropout_rate=self.dropout_rate,
+ use_norm=True,
+ )
+ )
+            if i == 0:
+                decoder.append(Attention(in_channels=out_channels))
+ decoder.append(Upsample())
+
+ decoder += [
+ Normalize(num_channels=self.hidden_dim * out_channels_multipliers[-1]),
+ nn.Mish(),
+ nn.Conv2d(
+ in_channels=self.hidden_dim * out_channels_multipliers[-1],
+ out_channels=self.out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ ),
+ ]
+ return nn.Sequential(*decoder)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Decodes latent vector."""
+ return self.decoder(x)
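
A round-trip shape sketch for the new pair, assuming Downsample and Upsample from resize.py halve and double the spatial resolution respectively (numbers illustrative):

    import torch
    from text_recognizer.networks.vqvae.pixelcnn import Decoder, Encoder

    encoder = Encoder(
        in_channels=1, hidden_dim=32, channels_multipliers=(2, 4), dropout_rate=0.1
    )
    decoder = Decoder(
        hidden_dim=32, channels_multipliers=(4, 2), out_channels=1, dropout_rate=0.1
    )
    x = torch.randn(2, 1, 32, 32)
    z = encoder(x)      # [2, 32 * 4, 8, 8]: one Downsample per multiplier
    x_hat = decoder(z)  # one Upsample per multiplier brings it back
    assert x_hat.shape == x.shape
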
diff --git a/text_recognizer/networks/vqvae/quantizer.py b/text_recognizer/networks/vqvae/quantizer.py
index a4f11f0..6fb57e8 100644
--- a/text_recognizer/networks/vqvae/quantizer.py
+++ b/text_recognizer/networks/vqvae/quantizer.py
@@ -11,13 +11,15 @@ import torch.nn.functional as F
class EmbeddingEMA(nn.Module):
+ """Embedding for Exponential Moving Average (EMA)."""
+
def __init__(self, num_embeddings: int, embedding_dim: int) -> None:
super().__init__()
weight = torch.zeros(num_embeddings, embedding_dim)
nn.init.kaiming_uniform_(weight, nonlinearity="linear")
self.register_buffer("weight", weight)
- self.register_buffer("_cluster_size", torch.zeros(num_embeddings))
- self.register_buffer("_weight_avg", weight)
+ self.register_buffer("cluster_size", torch.zeros(num_embeddings))
+ self.register_buffer("weight_avg", weight.clone())
class VectorQuantizer(nn.Module):
@@ -81,16 +83,17 @@ class VectorQuantizer(nn.Module):
return quantized_latent
def compute_ema(self, one_hot_encoding: Tensor, latent: Tensor) -> None:
+ """Computes the EMA update to the codebook."""
batch_cluster_size = one_hot_encoding.sum(axis=0)
batch_embedding_avg = (latent.t() @ one_hot_encoding).t()
- self.embedding._cluster_size.data.mul_(self.decay).add_(
+ self.embedding.cluster_size.data.mul_(self.decay).add_(
batch_cluster_size, alpha=1 - self.decay
)
- self.embedding._weight_avg.data.mul_(self.decay).add_(
+ self.embedding.weight_avg.data.mul_(self.decay).add_(
batch_embedding_avg, alpha=1 - self.decay
)
- new_embedding = self.embedding._weight_avg / (
- self.embedding._cluster_size + 1.0e-5
+ new_embedding = self.embedding.weight_avg / (
+ self.embedding.cluster_size + 1.0e-5
).unsqueeze(1)
self.embedding.weight.data.copy_(new_embedding)
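
Beyond the rename, the .clone() matters: register_buffer stores the tensor object itself, so registering weight a second time would make weight and weight_avg share storage, and the in-place EMA updates would then silently overwrite the live codebook. A standalone demonstration of the aliasing pitfall:

    import torch
    from torch import nn

    class Aliased(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            w = torch.zeros(4, 2)
            self.register_buffer("weight", w)
            self.register_buffer("weight_avg", w)  # same storage as "weight"!

    m = Aliased()
    m.weight_avg.data.add_(1.0)
    assert torch.equal(m.weight, m.weight_avg)  # the codebook moved too

With weight.clone() the two buffers start equal but update independently, which is what the EMA statistics require.
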
diff --git a/text_recognizer/networks/vqvae/residual.py b/text_recognizer/networks/vqvae/residual.py
index 98109b8..4ed3781 100644
--- a/text_recognizer/networks/vqvae/residual.py
+++ b/text_recognizer/networks/vqvae/residual.py
@@ -1,18 +1,55 @@
"""Residual block."""
+import attr
from torch import nn
from torch import Tensor
+from text_recognizer.networks.vqvae.norm import Normalize
+
+@attr.s(eq=False)
class Residual(nn.Module):
- def __init__(self, in_channels: int, out_channels: int,) -> None:
+ in_channels: int = attr.ib()
+ out_channels: int = attr.ib()
+ dropout_rate: float = attr.ib(default=0.0)
+ use_norm: bool = attr.ib(default=False)
+
+ def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
super().__init__()
- self.block = nn.Sequential(
- nn.Mish(inplace=True),
- nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
- nn.Mish(inplace=True),
- nn.Conv2d(out_channels, in_channels, kernel_size=1, bias=False),
- )
+ self.block = self._build_res_block()
+ if self.in_channels != self.out_channels:
+            self.conv_shortcut = nn.Conv2d(
+                in_channels=self.in_channels,
+                out_channels=self.out_channels,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+            )
+ else:
+ self.conv_shortcut = None
+
+ def _build_res_block(self) -> nn.Sequential:
+ """Build residual block."""
+ block = []
+ if self.use_norm:
+ block.append(Normalize(num_channels=self.in_channels))
+ block += [
+ nn.Mish(),
+ nn.Conv2d(
+ self.in_channels,
+ self.out_channels,
+ kernel_size=3,
+ padding=1,
+ bias=False,
+ ),
+ ]
+ if self.dropout_rate:
+ block += [nn.Dropout(p=self.dropout_rate)]
+
+ if self.use_norm:
+ block.append(Normalize(num_channels=self.out_channels))
+
+ block += [
+ nn.Mish(),
+ nn.Conv2d(self.out_channels, self.out_channels, kernel_size=1, bias=False),
+ ]
+ return nn.Sequential(*block)
def forward(self, x: Tensor) -> Tensor:
"""Apply the residual forward pass."""
- return x + self.block(x)
+ residual = self.conv_shortcut(x) if self.conv_shortcut is not None else x
+ return residual + self.block(x)
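
Because the block can now change its channel count, a 3x3 projection maps the input onto the output width before the addition whenever in_channels != out_channels. A quick shape check (illustrative numbers):

    import torch
    from text_recognizer.networks.vqvae.residual import Residual

    block = Residual(in_channels=32, out_channels=64, dropout_rate=0.1, use_norm=True)
    x = torch.randn(2, 32, 16, 16)
    assert block(x).shape == (2, 64, 16, 16)
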
diff --git a/text_recognizer/networks/vqvae/resize.py b/text_recognizer/networks/vqvae/resize.py
index 769d089..8d67d02 100644
--- a/text_recognizer/networks/vqvae/resize.py
+++ b/text_recognizer/networks/vqvae/resize.py
@@ -8,7 +8,7 @@ class Upsample(nn.Module):
def forward(self, x: Tensor) -> Tensor:
"""Applies upsampling."""
- return F.interpolate(x, scale_factor=2, mode="nearest")
+ return F.interpolate(x, scale_factor=2.0, mode="nearest")
class Downsample(nn.Module):
diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py
index 1585d40..0646119 100644
--- a/text_recognizer/networks/vqvae/vqvae.py
+++ b/text_recognizer/networks/vqvae/vqvae.py
@@ -1,13 +1,9 @@
"""The VQ-VAE."""
from typing import Tuple
-import torch
from torch import nn
from torch import Tensor
-import torch.nn.functional as F
-from text_recognizer.networks.vqvae.decoder import Decoder
-from text_recognizer.networks.vqvae.encoder import Encoder
from text_recognizer.networks.vqvae.quantizer import VectorQuantizer
@@ -16,93 +12,45 @@ class VQVAE(nn.Module):
def __init__(
self,
- in_channels: int,
- res_channels: int,
- num_residual_layers: int,
+ encoder: nn.Module,
+ decoder: nn.Module,
+ hidden_dim: int,
embedding_dim: int,
num_embeddings: int,
decay: float = 0.99,
- activation: str = "mish",
) -> None:
super().__init__()
- # Encoders
- self.btm_encoder = Encoder(
- in_channels=1,
- out_channels=embedding_dim,
- res_channels=res_channels,
- num_residual_layers=num_residual_layers,
- embedding_dim=embedding_dim,
- activation=activation,
+ self.encoder = encoder
+ self.decoder = decoder
+ self.pre_codebook_conv = nn.Conv2d(
+ in_channels=hidden_dim, out_channels=embedding_dim, kernel_size=1
)
-
- self.top_encoder = Encoder(
- in_channels=embedding_dim,
- out_channels=embedding_dim,
- res_channels=res_channels,
- num_residual_layers=num_residual_layers,
- embedding_dim=embedding_dim,
- activation=activation,
+ self.post_codebook_conv = nn.Conv2d(
+ in_channels=embedding_dim, out_channels=hidden_dim, kernel_size=1
)
-
- # Quantizers
- self.btm_quantizer = VectorQuantizer(
- num_embeddings=num_embeddings, embedding_dim=embedding_dim, decay=decay,
- )
-
- self.top_quantizer = VectorQuantizer(
+ self.quantizer = VectorQuantizer(
num_embeddings=num_embeddings, embedding_dim=embedding_dim, decay=decay,
)
- # Decoders
- self.top_decoder = Decoder(
- in_channels=embedding_dim,
- out_channels=embedding_dim,
- embedding_dim=embedding_dim,
- res_channels=res_channels,
- num_residual_layers=num_residual_layers,
- activation=activation,
- )
-
- self.btm_decoder = Decoder(
- in_channels=2 * embedding_dim,
- out_channels=in_channels,
- embedding_dim=embedding_dim,
- res_channels=res_channels,
- num_residual_layers=num_residual_layers,
- activation=activation,
- )
- def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ def encode(self, x: Tensor) -> Tensor:
"""Encodes input to a latent code."""
- z_btm = self.btm_encoder(x)
- z_top = self.top_encoder(z_btm)
- return z_btm, z_top
+ z_e = self.encoder(x)
+ return self.pre_codebook_conv(z_e)
- def quantize(
- self, z_btm: Tensor, z_top: Tensor
- ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
- q_btm, vq_btm_loss = self.top_quantizer(z_btm)
- q_top, vq_top_loss = self.top_quantizer(z_top)
- return q_btm, vq_btm_loss, q_top, vq_top_loss
+ def quantize(self, z_e: Tensor) -> Tuple[Tensor, Tensor]:
+ z_q, vq_loss = self.quantizer(z_e)
+ return z_q, vq_loss
- def decode(self, q_btm: Tensor, q_top: Tensor) -> Tuple[Tensor, Tensor]:
+ def decode(self, z_q: Tensor) -> Tensor:
"""Reconstructs input from latent codes."""
- d_top = self.top_decoder(q_top)
- x_hat = self.btm_decoder(torch.cat((d_top, q_btm), dim=1))
- return d_top, x_hat
-
- def loss_fn(
- self, vq_btm_loss: Tensor, vq_top_loss: Tensor, d_top: Tensor, z_btm: Tensor
- ) -> Tensor:
- """Calculates the latent loss."""
- return 0.5 * (vq_top_loss + vq_btm_loss) + F.mse_loss(d_top, z_btm)
+ z = self.post_codebook_conv(z_q)
+ x_hat = self.decoder(z)
+ return x_hat
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
"""Compresses and decompresses input."""
- z_btm, z_top = self.encode(x)
- q_btm, vq_btm_loss, q_top, vq_top_loss = self.quantize(z_btm=z_btm, z_top=z_top)
- d_top, x_hat = self.decode(q_btm=q_btm, q_top=q_top)
- vq_loss = self.loss_fn(
- vq_btm_loss=vq_btm_loss, vq_top_loss=vq_top_loss, d_top=d_top, z_btm=z_btm
- )
+ z_e = self.encode(x)
+ z_q, vq_loss = self.quantize(z_e)
+ x_hat = self.decode(z_q)
return x_hat, vq_loss
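
With the encoder and decoder injected rather than hard-wired, composing the full model is left to configuration. A minimal end-to-end sketch using the rebuilt classes from this diff (hyperparameters illustrative; hidden_dim must equal the channel count leaving the encoder, and Encoder.forward is assumed to return the feature map directly):

    import torch
    from text_recognizer.networks.vqvae.decoder import Decoder
    from text_recognizer.networks.vqvae.encoder import Encoder
    from text_recognizer.networks.vqvae.vqvae import VQVAE

    dim, mults = 32, (1, 2, 4)
    vqvae = VQVAE(
        encoder=Encoder(
            in_channels=1,
            hidden_dim=dim,
            channels_multipliers=list(mults),
            dropout_rate=0.1,
        ),
        decoder=Decoder(
            out_channels=1,
            hidden_dim=dim,
            channels_multipliers=tuple(reversed(mults)),
            dropout_rate=0.1,
        ),
        hidden_dim=dim * mults[-1],  # channels leaving the encoder
        embedding_dim=64,
        num_embeddings=512,
    )
    x = torch.randn(2, 1, 64, 64)
    x_hat, vq_loss = vqvae(x)
    assert x_hat.shape == x.shape
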