From 3ab82ad36bce6fa698a13a029a0694b75a5947b7 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Fri, 6 Aug 2021 02:42:45 +0200 Subject: Fix VQVAE into en/decoder, bug in wandb artifact code uploading --- text_recognizer/networks/conv_transformer.py | 7 +- text_recognizer/networks/vq_transformer.py | 50 ++++---- text_recognizer/networks/vqvae/__init__.py | 1 - text_recognizer/networks/vqvae/attention.py | 7 +- text_recognizer/networks/vqvae/decoder.py | 83 ++++++++------ text_recognizer/networks/vqvae/encoder.py | 82 ++++++------- text_recognizer/networks/vqvae/norm.py | 4 +- text_recognizer/networks/vqvae/pixelcnn.py | 165 +++++++++++++++++++++++++++ text_recognizer/networks/vqvae/quantizer.py | 15 ++- text_recognizer/networks/vqvae/residual.py | 53 +++++++-- text_recognizer/networks/vqvae/resize.py | 2 +- text_recognizer/networks/vqvae/vqvae.py | 98 ++++------------ 12 files changed, 359 insertions(+), 208 deletions(-) create mode 100644 text_recognizer/networks/vqvae/pixelcnn.py (limited to 'text_recognizer/networks') diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index f3ba49d..b1a101e 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -4,7 +4,6 @@ from typing import Tuple from torch import nn, Tensor -from text_recognizer.networks.encoders.efficientnet import EfficientNet from text_recognizer.networks.transformer.layers import Decoder from text_recognizer.networks.transformer.positional_encodings import ( PositionalEncoding, @@ -18,15 +17,17 @@ class ConvTransformer(nn.Module): def __init__( self, input_dims: Tuple[int, int, int], + encoder_dim: int, hidden_dim: int, dropout_rate: float, num_classes: int, pad_index: Tensor, - encoder: EfficientNet, + encoder: nn.Module, decoder: Decoder, ) -> None: super().__init__() self.input_dims = input_dims + self.encoder_dim = encoder_dim self.hidden_dim = hidden_dim self.dropout_rate = dropout_rate self.num_classes = num_classes @@ -38,7 +39,7 @@ class ConvTransformer(nn.Module): # positional encoding. self.latent_encoder = nn.Sequential( nn.Conv2d( - in_channels=self.encoder.out_channels, + in_channels=self.encoder_dim, out_channels=self.hidden_dim, kernel_size=1, ), diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py index a972565..0433863 100644 --- a/text_recognizer/networks/vq_transformer.py +++ b/text_recognizer/networks/vq_transformer.py @@ -1,16 +1,12 @@ """Vector quantized encoder, transformer decoder.""" -import math from typing import Tuple -from torch import nn, Tensor +import torch +from torch import Tensor -from text_recognizer.networks.encoders.efficientnet import EfficientNet +from text_recognizer.networks.vqvae.vqvae import VQVAE from text_recognizer.networks.conv_transformer import ConvTransformer from text_recognizer.networks.transformer.layers import Decoder -from text_recognizer.networks.transformer.positional_encodings import ( - PositionalEncoding, - PositionalEncoding2D, -) class VqTransformer(ConvTransformer): @@ -19,16 +15,18 @@ class VqTransformer(ConvTransformer): def __init__( self, input_dims: Tuple[int, int, int], + encoder_dim: int, hidden_dim: int, dropout_rate: float, num_classes: int, pad_index: Tensor, - encoder: EfficientNet, + encoder: VQVAE, decoder: Decoder, + pretrained_encoder_path: str, ) -> None: - # TODO: Load pretrained vqvae encoder. super().__init__( input_dims=input_dims, + encoder_dim=encoder_dim, hidden_dim=hidden_dim, dropout_rate=dropout_rate, num_classes=num_classes, @@ -36,24 +34,19 @@ class VqTransformer(ConvTransformer): encoder=encoder, decoder=decoder, ) - # Latent projector for down sampling number of filters and 2d - # positional encoding. - self.latent_encoder = nn.Sequential( - nn.Conv2d( - in_channels=self.encoder.out_channels, - out_channels=self.hidden_dim, - kernel_size=1, - ), - PositionalEncoding2D( - hidden_dim=self.hidden_dim, - max_h=self.input_dims[1], - max_w=self.input_dims[2], - ), - nn.Flatten(start_dim=2), - ) + self.pretrained_encoder_path = pretrained_encoder_path + + # For typing + self.encoder: VQVAE + + def setup_encoder(self) -> None: + """Remove unecessary layers.""" + # TODO: load pretrained vqvae + del self.encoder.decoder + del self.encoder.post_codebook_conv def encode(self, x: Tensor) -> Tensor: - """Encodes an image into a latent feature vector. + """Encodes an image into a discrete (VQ) latent representation. Args: x (Tensor): Image tensor. @@ -69,8 +62,11 @@ class VqTransformer(ConvTransformer): Returns: Tensor: A Latent embedding of the image. """ - z = self.encoder(x) - z = self.latent_encoder(z) + with torch.no_grad(): + z_e = self.encoder.encode(x) + z_q, _ = self.encoder.quantize(z_e) + + z = self.latent_encoder(z_q) # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] z = z.permute(0, 2, 1) diff --git a/text_recognizer/networks/vqvae/__init__.py b/text_recognizer/networks/vqvae/__init__.py index 7d56bdb..e1f05fa 100644 --- a/text_recognizer/networks/vqvae/__init__.py +++ b/text_recognizer/networks/vqvae/__init__.py @@ -1,2 +1 @@ """VQ-VAE module.""" -from .vqvae import VQVAE diff --git a/text_recognizer/networks/vqvae/attention.py b/text_recognizer/networks/vqvae/attention.py index 5a6b3ce..78a2cc9 100644 --- a/text_recognizer/networks/vqvae/attention.py +++ b/text_recognizer/networks/vqvae/attention.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from text_recognizer.networks.vqvae.norm import Normalize -@attr.s +@attr.s(eq=False) class Attention(nn.Module): """Convolutional attention.""" @@ -63,11 +63,12 @@ class Attention(nn.Module): B, C, H, W = q.shape q = q.reshape(B, C, H * W).permute(0, 2, 1) # [B, HW, C] k = k.reshape(B, C, H * W) # [B, C, HW] - energy = torch.bmm(q, k) * (C ** -0.5) + energy = torch.bmm(q, k) * (int(C) ** -0.5) attention = F.softmax(energy, dim=2) # Compute attention to which values - v = v.reshape(B, C, H * W).permute(0, 2, 1) # [B, HW, C] + v = v.reshape(B, C, H * W) + attention = attention.permute(0, 2, 1) # [B, HW, HW] out = torch.bmm(v, attention) out = out.reshape(B, C, H, W) out = self.proj(out) diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py index fcf768b..f51e0a3 100644 --- a/text_recognizer/networks/vqvae/decoder.py +++ b/text_recognizer/networks/vqvae/decoder.py @@ -1,62 +1,69 @@ """CNN decoder for the VQ-VAE.""" -import attr +from typing import Sequence + from torch import nn from torch import Tensor from text_recognizer.networks.util import activation_function +from text_recognizer.networks.vqvae.norm import Normalize from text_recognizer.networks.vqvae.residual import Residual -@attr.s(eq=False) class Decoder(nn.Module): """A CNN encoder network.""" - in_channels: int = attr.ib() - embedding_dim: int = attr.ib() - out_channels: int = attr.ib() - res_channels: int = attr.ib() - num_residual_layers: int = attr.ib() - activation: str = attr.ib() - decoder: nn.Sequential = attr.ib(init=False) - - def __attrs_post_init__(self) -> None: - """Post init configuration.""" + def __init__(self, out_channels: int, hidden_dim: int, channels_multipliers: Sequence[int], dropout_rate: float, activation: str = "mish") -> None: super().__init__() + self.out_channels = out_channels + self.hidden_dim = hidden_dim + self.channels_multipliers = tuple(channels_multipliers) + self.activation = activation + self.dropout_rate = dropout_rate self.decoder = self._build_decompression_block() def _build_decompression_block(self,) -> nn.Sequential: + in_channels = self.hidden_dim * self.channels_multipliers[0] + decoder = [] + for _ in range(2): + decoder += [ + Residual( + in_channels=in_channels, + out_channels=in_channels, + dropout_rate=self.dropout_rate, + use_norm=True, + ), + ] + activation_fn = activation_function(self.activation) - blocks = [ + out_channels_multipliers = self.channels_multipliers + (1, ) + num_blocks = len(self.channels_multipliers) + + for i in range(num_blocks): + in_channels = self.hidden_dim * self.channels_multipliers[i] + out_channels = self.hidden_dim * out_channels_multipliers[i + 1] + decoder += [ + nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=4, + stride=2, + padding=1, + ), + activation_fn, + ] + + decoder += [ + Normalize(num_channels=self.hidden_dim * out_channels_multipliers[-1]), + nn.Mish(), nn.Conv2d( - in_channels=self.in_channels, - out_channels=self.embedding_dim, - kernel_size=3, - padding=1, - ) - ] - for _ in range(self.num_residual_layers): - blocks.append( - Residual(in_channels=self.embedding_dim, out_channels=self.res_channels) - ) - blocks.append(activation_fn) - blocks += [ - nn.ConvTranspose2d( - in_channels=self.embedding_dim, - out_channels=self.embedding_dim // 2, - kernel_size=4, - stride=2, - padding=1, - ), - activation_fn, - nn.ConvTranspose2d( - in_channels=self.embedding_dim // 2, + in_channels=self.hidden_dim * out_channels_multipliers[-1], out_channels=self.out_channels, - kernel_size=4, - stride=2, + kernel_size=3, + stride=1, padding=1, ), ] - return nn.Sequential(*blocks) + return nn.Sequential(*decoder) def forward(self, z_q: Tensor) -> Tensor: """Reconstruct input from given codes.""" diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py index f086c6b..ad8f950 100644 --- a/text_recognizer/networks/vqvae/encoder.py +++ b/text_recognizer/networks/vqvae/encoder.py @@ -1,7 +1,6 @@ """CNN encoder for the VQ-VAE.""" -from typing import Sequence, Optional, Tuple, Type +from typing import List, Tuple -import attr from torch import nn from torch import Tensor @@ -9,64 +8,59 @@ from text_recognizer.networks.util import activation_function from text_recognizer.networks.vqvae.residual import Residual -@attr.s(eq=False) class Encoder(nn.Module): """A CNN encoder network.""" - in_channels: int = attr.ib() - out_channels: int = attr.ib() - res_channels: int = attr.ib() - num_residual_layers: int = attr.ib() - embedding_dim: int = attr.ib() - activation: str = attr.ib() - encoder: nn.Sequential = attr.ib(init=False) - - def __attrs_post_init__(self) -> None: - """Post init configuration.""" + def __init__(self, in_channels: int, hidden_dim: int, channels_multipliers: List[int], dropout_rate: float, activation: str = "mish") -> None: super().__init__() + self.in_channels = in_channels + self.hidden_dim = hidden_dim + self.channels_multipliers = tuple(channels_multipliers) + self.activation = activation + self.dropout_rate = dropout_rate self.encoder = self._build_compression_block() def _build_compression_block(self) -> nn.Sequential: - activation_fn = activation_function(self.activation) - block = [ + """Builds encoder network.""" + encoder = [ nn.Conv2d( in_channels=self.in_channels, - out_channels=self.out_channels // 2, - kernel_size=4, - stride=2, - padding=1, - ), - activation_fn, - nn.Conv2d( - in_channels=self.out_channels // 2, - out_channels=self.out_channels, - kernel_size=4, - stride=2, - padding=1, - ), - activation_fn, - nn.Conv2d( - in_channels=self.out_channels, - out_channels=self.out_channels, + out_channels=self.hidden_dim, kernel_size=3, + stride=1, padding=1, ), ] - for _ in range(self.num_residual_layers): - block.append( - Residual(in_channels=self.out_channels, out_channels=self.res_channels) - ) + num_blocks = len(self.channels_multipliers) + channels_multipliers = (1, ) + self.channels_multipliers + activation_fn = activation_function(self.activation) - block.append( - nn.Conv2d( - in_channels=self.out_channels, - out_channels=self.embedding_dim, - kernel_size=1, - ) - ) + for i in range(num_blocks): + in_channels = self.hidden_dim * channels_multipliers[i] + out_channels = self.hidden_dim * channels_multipliers[i + 1] + encoder += [ + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=4, + stride=2, + padding=1, + ), + activation_fn, + ] + + for _ in range(2): + encoder += [ + Residual( + in_channels=self.hidden_dim * self.channels_multipliers[-1], + out_channels=self.hidden_dim * self.channels_multipliers[-1], + dropout_rate=self.dropout_rate, + use_norm=True, + ) + ] - return nn.Sequential(*block) + return nn.Sequential(*encoder) def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: """Encodes input into a discrete representation.""" diff --git a/text_recognizer/networks/vqvae/norm.py b/text_recognizer/networks/vqvae/norm.py index df66efc..3e6963a 100644 --- a/text_recognizer/networks/vqvae/norm.py +++ b/text_recognizer/networks/vqvae/norm.py @@ -3,7 +3,7 @@ import attr from torch import nn, Tensor -@attr.s +@attr.s(eq=False) class Normalize(nn.Module): num_channels: int = attr.ib() norm: nn.GroupNorm = attr.ib(init=False) @@ -12,7 +12,7 @@ class Normalize(nn.Module): """Post init configuration.""" super().__init__() self.norm = nn.GroupNorm( - num_groups=32, num_channels=self.num_channels, eps=1.0e-6, affine=True + num_groups=self.num_channels, num_channels=self.num_channels, eps=1.0e-6, affine=True ) def forward(self, x: Tensor) -> Tensor: diff --git a/text_recognizer/networks/vqvae/pixelcnn.py b/text_recognizer/networks/vqvae/pixelcnn.py new file mode 100644 index 0000000..5c580df --- /dev/null +++ b/text_recognizer/networks/vqvae/pixelcnn.py @@ -0,0 +1,165 @@ +"""PixelCNN encoder and decoder. + +Same as in VQGAN among other. Hopefully, better reconstructions... + +TODO: Add num of residual layers. +""" +from typing import Sequence + +from torch import nn +from torch import Tensor + +from text_recognizer.networks.vqvae.attention import Attention +from text_recognizer.networks.vqvae.norm import Normalize +from text_recognizer.networks.vqvae.residual import Residual +from text_recognizer.networks.vqvae.resize import Downsample, Upsample + + +class Encoder(nn.Module): + """PixelCNN encoder.""" + + def __init__( + self, + in_channels: int, + hidden_dim: int, + channels_multipliers: Sequence[int], + dropout_rate: float, + ) -> None: + super().__init__() + self.in_channels = in_channels + self.dropout_rate = dropout_rate + self.hidden_dim = hidden_dim + self.channels_multipliers = tuple(channels_multipliers) + self.encoder = self._build_encoder() + + def _build_encoder(self) -> nn.Sequential: + """Builds encoder.""" + encoder = [ + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.hidden_dim, + kernel_size=3, + stride=1, + padding=1, + ), + ] + num_blocks = len(self.channels_multipliers) + in_channels_multipliers = (1,) + self.channels_multipliers + for i in range(num_blocks): + in_channels = self.hidden_dim * in_channels_multipliers[i] + out_channels = self.hidden_dim * self.channels_multipliers[i] + encoder += [ + Residual( + in_channels=in_channels, + out_channels=out_channels, + dropout_rate=self.dropout_rate, + use_norm=True, + ), + ] + if i == num_blocks - 1: + encoder.append(Attention(in_channels=out_channels)) + encoder.append(Downsample()) + + for _ in range(2): + encoder += [ + Residual( + in_channels=self.hidden_dim * self.channels_multipliers[-1], + out_channels=self.hidden_dim * self.channels_multipliers[-1], + dropout_rate=self.dropout_rate, + use_norm=True, + ), + Attention(in_channels=self.hidden_dim * self.channels_multipliers[-1]) + ] + + encoder += [ + Normalize(num_channels=self.hidden_dim * self.channels_multipliers[-1]), + nn.Mish(), + nn.Conv2d( + in_channels=self.hidden_dim * self.channels_multipliers[-1], + out_channels=self.hidden_dim * self.channels_multipliers[-1], + kernel_size=3, + stride=1, + padding=1, + ), + ] + return nn.Sequential(*encoder) + + def forward(self, x: Tensor) -> Tensor: + """Encodes input to a latent representation.""" + return self.encoder(x) + + +class Decoder(nn.Module): + """PixelCNN decoder.""" + + def __init__( + self, + hidden_dim: int, + channels_multipliers: Sequence[int], + out_channels: int, + dropout_rate: float, + ) -> None: + super().__init__() + self.hidden_dim = hidden_dim + self.out_channels = out_channels + self.channels_multipliers = tuple(channels_multipliers) + self.dropout_rate = dropout_rate + self.decoder = self._build_decoder() + + def _build_decoder(self) -> nn.Sequential: + """Builds decoder.""" + in_channels = self.hidden_dim * self.channels_multipliers[0] + decoder = [ + Residual( + in_channels=in_channels, + out_channels=in_channels, + dropout_rate=self.dropout_rate, + use_norm=True, + ), + Attention(in_channels=in_channels), + Residual( + in_channels=in_channels, + out_channels=in_channels, + dropout_rate=self.dropout_rate, + use_norm=True, + ), + ] + + out_channels_multipliers = self.channels_multipliers + (1, ) + num_blocks = len(self.channels_multipliers) + + for i in range(num_blocks): + in_channels = self.hidden_dim * self.channels_multipliers[i] + out_channels = self.hidden_dim * out_channels_multipliers[i + 1] + decoder.append( + Residual( + in_channels=in_channels, + out_channels=out_channels, + dropout_rate=self.dropout_rate, + use_norm=True, + ) + ) + if i == 0: + decoder.append( + Attention( + in_channels=out_channels + ) + ) + decoder.append(Upsample()) + + decoder += [ + Normalize(num_channels=self.hidden_dim * out_channels_multipliers[-1]), + nn.Mish(), + nn.Conv2d( + in_channels=self.hidden_dim * out_channels_multipliers[-1], + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1, + ), + ] + return nn.Sequential(*decoder) + + def forward(self, x: Tensor) -> Tensor: + """Decodes latent vector.""" + return self.decoder(x) diff --git a/text_recognizer/networks/vqvae/quantizer.py b/text_recognizer/networks/vqvae/quantizer.py index a4f11f0..6fb57e8 100644 --- a/text_recognizer/networks/vqvae/quantizer.py +++ b/text_recognizer/networks/vqvae/quantizer.py @@ -11,13 +11,15 @@ import torch.nn.functional as F class EmbeddingEMA(nn.Module): + """Embedding for Exponential Moving Average (EMA).""" + def __init__(self, num_embeddings: int, embedding_dim: int) -> None: super().__init__() weight = torch.zeros(num_embeddings, embedding_dim) nn.init.kaiming_uniform_(weight, nonlinearity="linear") self.register_buffer("weight", weight) - self.register_buffer("_cluster_size", torch.zeros(num_embeddings)) - self.register_buffer("_weight_avg", weight) + self.register_buffer("cluster_size", torch.zeros(num_embeddings)) + self.register_buffer("weight_avg", weight.clone()) class VectorQuantizer(nn.Module): @@ -81,16 +83,17 @@ class VectorQuantizer(nn.Module): return quantized_latent def compute_ema(self, one_hot_encoding: Tensor, latent: Tensor) -> None: + """Computes the EMA update to the codebook.""" batch_cluster_size = one_hot_encoding.sum(axis=0) batch_embedding_avg = (latent.t() @ one_hot_encoding).t() - self.embedding._cluster_size.data.mul_(self.decay).add_( + self.embedding.cluster_size.data.mul_(self.decay).add_( batch_cluster_size, alpha=1 - self.decay ) - self.embedding._weight_avg.data.mul_(self.decay).add_( + self.embedding.weight_avg.data.mul_(self.decay).add_( batch_embedding_avg, alpha=1 - self.decay ) - new_embedding = self.embedding._weight_avg / ( - self.embedding._cluster_size + 1.0e-5 + new_embedding = self.embedding.weight_avg / ( + self.embedding.cluster_size + 1.0e-5 ).unsqueeze(1) self.embedding.weight.data.copy_(new_embedding) diff --git a/text_recognizer/networks/vqvae/residual.py b/text_recognizer/networks/vqvae/residual.py index 98109b8..4ed3781 100644 --- a/text_recognizer/networks/vqvae/residual.py +++ b/text_recognizer/networks/vqvae/residual.py @@ -1,18 +1,55 @@ """Residual block.""" +import attr from torch import nn from torch import Tensor +from text_recognizer.networks.vqvae.norm import Normalize + +@attr.s(eq=False) class Residual(nn.Module): - def __init__(self, in_channels: int, out_channels: int,) -> None: + in_channels: int = attr.ib() + out_channels: int = attr.ib() + dropout_rate: float = attr.ib(default=0.0) + use_norm: bool = attr.ib(default=False) + + def __attrs_post_init__(self) -> None: + """Post init configuration.""" super().__init__() - self.block = nn.Sequential( - nn.Mish(inplace=True), - nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), - nn.Mish(inplace=True), - nn.Conv2d(out_channels, in_channels, kernel_size=1, bias=False), - ) + self.block = self._build_res_block() + if self.in_channels != self.out_channels: + self.conv_shortcut = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=3, stride=1, padding=1) + else: + self.conv_shortcut = None + + def _build_res_block(self) -> nn.Sequential: + """Build residual block.""" + block = [] + if self.use_norm: + block.append(Normalize(num_channels=self.in_channels)) + block += [ + nn.Mish(), + nn.Conv2d( + self.in_channels, + self.out_channels, + kernel_size=3, + padding=1, + bias=False, + ), + ] + if self.dropout_rate: + block += [nn.Dropout(p=self.dropout_rate)] + + if self.use_norm: + block.append(Normalize(num_channels=self.out_channels)) + + block += [ + nn.Mish(), + nn.Conv2d(self.out_channels, self.out_channels, kernel_size=1, bias=False), + ] + return nn.Sequential(*block) def forward(self, x: Tensor) -> Tensor: """Apply the residual forward pass.""" - return x + self.block(x) + residual = self.conv_shortcut(x) if self.conv_shortcut is not None else x + return residual + self.block(x) diff --git a/text_recognizer/networks/vqvae/resize.py b/text_recognizer/networks/vqvae/resize.py index 769d089..8d67d02 100644 --- a/text_recognizer/networks/vqvae/resize.py +++ b/text_recognizer/networks/vqvae/resize.py @@ -8,7 +8,7 @@ class Upsample(nn.Module): def forward(self, x: Tensor) -> Tensor: """Applies upsampling.""" - return F.interpolate(x, scale_factor=2, mode="nearest") + return F.interpolate(x, scale_factor=2.0, mode="nearest") class Downsample(nn.Module): diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py index 1585d40..0646119 100644 --- a/text_recognizer/networks/vqvae/vqvae.py +++ b/text_recognizer/networks/vqvae/vqvae.py @@ -1,13 +1,9 @@ """The VQ-VAE.""" from typing import Tuple -import torch from torch import nn from torch import Tensor -import torch.nn.functional as F -from text_recognizer.networks.vqvae.decoder import Decoder -from text_recognizer.networks.vqvae.encoder import Encoder from text_recognizer.networks.vqvae.quantizer import VectorQuantizer @@ -16,93 +12,45 @@ class VQVAE(nn.Module): def __init__( self, - in_channels: int, - res_channels: int, - num_residual_layers: int, + encoder: nn.Module, + decoder: nn.Module, + hidden_dim: int, embedding_dim: int, num_embeddings: int, decay: float = 0.99, - activation: str = "mish", ) -> None: super().__init__() - # Encoders - self.btm_encoder = Encoder( - in_channels=1, - out_channels=embedding_dim, - res_channels=res_channels, - num_residual_layers=num_residual_layers, - embedding_dim=embedding_dim, - activation=activation, + self.encoder = encoder + self.decoder = decoder + self.pre_codebook_conv = nn.Conv2d( + in_channels=hidden_dim, out_channels=embedding_dim, kernel_size=1 ) - - self.top_encoder = Encoder( - in_channels=embedding_dim, - out_channels=embedding_dim, - res_channels=res_channels, - num_residual_layers=num_residual_layers, - embedding_dim=embedding_dim, - activation=activation, + self.post_codebook_conv = nn.Conv2d( + in_channels=embedding_dim, out_channels=hidden_dim, kernel_size=1 ) - - # Quantizers - self.btm_quantizer = VectorQuantizer( - num_embeddings=num_embeddings, embedding_dim=embedding_dim, decay=decay, - ) - - self.top_quantizer = VectorQuantizer( + self.quantizer = VectorQuantizer( num_embeddings=num_embeddings, embedding_dim=embedding_dim, decay=decay, ) - # Decoders - self.top_decoder = Decoder( - in_channels=embedding_dim, - out_channels=embedding_dim, - embedding_dim=embedding_dim, - res_channels=res_channels, - num_residual_layers=num_residual_layers, - activation=activation, - ) - - self.btm_decoder = Decoder( - in_channels=2 * embedding_dim, - out_channels=in_channels, - embedding_dim=embedding_dim, - res_channels=res_channels, - num_residual_layers=num_residual_layers, - activation=activation, - ) - def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]: + def encode(self, x: Tensor) -> Tensor: """Encodes input to a latent code.""" - z_btm = self.btm_encoder(x) - z_top = self.top_encoder(z_btm) - return z_btm, z_top + z_e = self.encoder(x) + return self.pre_codebook_conv(z_e) - def quantize( - self, z_btm: Tensor, z_top: Tensor - ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - q_btm, vq_btm_loss = self.top_quantizer(z_btm) - q_top, vq_top_loss = self.top_quantizer(z_top) - return q_btm, vq_btm_loss, q_top, vq_top_loss + def quantize(self, z_e: Tensor) -> Tuple[Tensor, Tensor]: + z_q, vq_loss = self.quantizer(z_e) + return z_q, vq_loss - def decode(self, q_btm: Tensor, q_top: Tensor) -> Tuple[Tensor, Tensor]: + def decode(self, z_q: Tensor) -> Tensor: """Reconstructs input from latent codes.""" - d_top = self.top_decoder(q_top) - x_hat = self.btm_decoder(torch.cat((d_top, q_btm), dim=1)) - return d_top, x_hat - - def loss_fn( - self, vq_btm_loss: Tensor, vq_top_loss: Tensor, d_top: Tensor, z_btm: Tensor - ) -> Tensor: - """Calculates the latent loss.""" - return 0.5 * (vq_top_loss + vq_btm_loss) + F.mse_loss(d_top, z_btm) + z = self.post_codebook_conv(z_q) + x_hat = self.decoder(z) + return x_hat def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: """Compresses and decompresses input.""" - z_btm, z_top = self.encode(x) - q_btm, vq_btm_loss, q_top, vq_top_loss = self.quantize(z_btm=z_btm, z_top=z_top) - d_top, x_hat = self.decode(q_btm=q_btm, q_top=q_top) - vq_loss = self.loss_fn( - vq_btm_loss=vq_btm_loss, vq_top_loss=vq_top_loss, d_top=d_top, z_btm=z_btm - ) + z_e = self.encode(x) + z_q, vq_loss = self.quantize(z_e) + x_hat = self.decode(z_q) return x_hat, vq_loss -- cgit v1.2.3-70-g09d2