| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-06 02:42:45 +0200 |
|---|---|---|
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-06 02:42:45 +0200 |
| commit | 3ab82ad36bce6fa698a13a029a0694b75a5947b7 (patch) | |
| tree | 136f71a62d60e3ccf01e1f95d64bb4d9f9c9befe /text_recognizer/networks/vqvae/decoder.py | |
| parent | 1bccf71cf4eec335001b50a8fbc0c991d0e6d13a (diff) | |
Refactor VQVAE into separate encoder/decoder; fix bug in wandb artifact code uploading
Diffstat (limited to 'text_recognizer/networks/vqvae/decoder.py')
-rw-r--r-- | text_recognizer/networks/vqvae/decoder.py | 83
1 file changed, 45 insertions(+), 38 deletions(-)
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
index fcf768b..f51e0a3 100644
--- a/text_recognizer/networks/vqvae/decoder.py
+++ b/text_recognizer/networks/vqvae/decoder.py
@@ -1,62 +1,69 @@
 """CNN decoder for the VQ-VAE."""
-import attr
+from typing import Sequence
+
 from torch import nn
 from torch import Tensor
 
 from text_recognizer.networks.util import activation_function
+from text_recognizer.networks.vqvae.norm import Normalize
 from text_recognizer.networks.vqvae.residual import Residual
 
 
-@attr.s(eq=False)
 class Decoder(nn.Module):
     """A CNN encoder network."""
 
-    in_channels: int = attr.ib()
-    embedding_dim: int = attr.ib()
-    out_channels: int = attr.ib()
-    res_channels: int = attr.ib()
-    num_residual_layers: int = attr.ib()
-    activation: str = attr.ib()
-    decoder: nn.Sequential = attr.ib(init=False)
-
-    def __attrs_post_init__(self) -> None:
-        """Post init configuration."""
+    def __init__(self, out_channels: int, hidden_dim: int, channels_multipliers: Sequence[int], dropout_rate: float, activation: str = "mish") -> None:
         super().__init__()
+        self.out_channels = out_channels
+        self.hidden_dim = hidden_dim
+        self.channels_multipliers = tuple(channels_multipliers)
+        self.activation = activation
+        self.dropout_rate = dropout_rate
         self.decoder = self._build_decompression_block()
 
     def _build_decompression_block(self,) -> nn.Sequential:
+        in_channels = self.hidden_dim * self.channels_multipliers[0]
+        decoder = []
+        for _ in range(2):
+            decoder += [
+                Residual(
+                    in_channels=in_channels,
+                    out_channels=in_channels,
+                    dropout_rate=self.dropout_rate,
+                    use_norm=True,
+                ),
+            ]
+
         activation_fn = activation_function(self.activation)
-        blocks = [
+        out_channels_multipliers = self.channels_multipliers + (1, )
+        num_blocks = len(self.channels_multipliers)
+
+        for i in range(num_blocks):
+            in_channels = self.hidden_dim * self.channels_multipliers[i]
+            out_channels = self.hidden_dim * out_channels_multipliers[i + 1]
+            decoder += [
+                nn.ConvTranspose2d(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    kernel_size=4,
+                    stride=2,
+                    padding=1,
+                ),
+                activation_fn,
+            ]
+
+        decoder += [
+            Normalize(num_channels=self.hidden_dim * out_channels_multipliers[-1]),
+            nn.Mish(),
             nn.Conv2d(
-                in_channels=self.in_channels,
-                out_channels=self.embedding_dim,
-                kernel_size=3,
-                padding=1,
-            )
-        ]
-        for _ in range(self.num_residual_layers):
-            blocks.append(
-                Residual(in_channels=self.embedding_dim, out_channels=self.res_channels)
-            )
-        blocks.append(activation_fn)
-        blocks += [
-            nn.ConvTranspose2d(
-                in_channels=self.embedding_dim,
-                out_channels=self.embedding_dim // 2,
-                kernel_size=4,
-                stride=2,
-                padding=1,
-            ),
-            activation_fn,
-            nn.ConvTranspose2d(
-                in_channels=self.embedding_dim // 2,
+                in_channels=self.hidden_dim * out_channels_multipliers[-1],
                 out_channels=self.out_channels,
-                kernel_size=4,
-                stride=2,
+                kernel_size=3,
+                stride=1,
                 padding=1,
             ),
         ]
-        return nn.Sequential(*blocks)
+        return nn.Sequential(*decoder)
 
     def forward(self, z_q: Tensor) -> Tensor:
         """Reconstruct input from given codes."""
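For reference, a minimal usage sketch of the refactored Decoder, based only on the constructor signature introduced in this diff. The hyperparameter values and input shape are illustrative (not taken from the repository's configs), and it assumes `forward` simply applies `self.decoder` to the quantized codes.

```python
import torch

from text_recognizer.networks.vqvae.decoder import Decoder

# Hypothetical hyperparameters, chosen only to illustrate the new interface.
decoder = Decoder(
    out_channels=1,                  # e.g. a grayscale reconstruction
    hidden_dim=32,
    channels_multipliers=(4, 2, 1),  # one ConvTranspose2d (2x upsample) per multiplier
    dropout_rate=0.25,
    activation="mish",
)

# The codes are expected to have hidden_dim * channels_multipliers[0] channels;
# three multipliers mean three 2x upsamplings, so 8x8 codes become 64x64 output.
z_q = torch.randn(1, 32 * 4, 8, 8)
reconstruction = decoder(z_q)
print(reconstruction.shape)  # torch.Size([1, 1, 64, 64])
```

Compared with the old attrs-based class, the decoder no longer takes `in_channels`, `embedding_dim`, `res_channels`, or `num_residual_layers`: two normalized Residual blocks are always prepended, the upsampling factor is `2 ** len(channels_multipliers)`, and the output head is now Normalize + Mish + a stride-1 Conv2d instead of a second ConvTranspose2d.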