summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/vqvae/decoder.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-06 02:42:45 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-06 02:42:45 +0200
commit3ab82ad36bce6fa698a13a029a0694b75a5947b7 (patch)
tree136f71a62d60e3ccf01e1f95d64bb4d9f9c9befe /text_recognizer/networks/vqvae/decoder.py
parent1bccf71cf4eec335001b50a8fbc0c991d0e6d13a (diff)
Fix VQVAE into en/decoder, bug in wandb artifact code uploading
Diffstat (limited to 'text_recognizer/networks/vqvae/decoder.py')
-rw-r--r--text_recognizer/networks/vqvae/decoder.py83
1 files changed, 45 insertions, 38 deletions
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
index fcf768b..f51e0a3 100644
--- a/text_recognizer/networks/vqvae/decoder.py
+++ b/text_recognizer/networks/vqvae/decoder.py
@@ -1,62 +1,69 @@
"""CNN decoder for the VQ-VAE."""
-import attr
+from typing import Sequence
+
from torch import nn
from torch import Tensor
from text_recognizer.networks.util import activation_function
+from text_recognizer.networks.vqvae.norm import Normalize
from text_recognizer.networks.vqvae.residual import Residual
-@attr.s(eq=False)
class Decoder(nn.Module):
"""A CNN encoder network."""
- in_channels: int = attr.ib()
- embedding_dim: int = attr.ib()
- out_channels: int = attr.ib()
- res_channels: int = attr.ib()
- num_residual_layers: int = attr.ib()
- activation: str = attr.ib()
- decoder: nn.Sequential = attr.ib(init=False)
-
- def __attrs_post_init__(self) -> None:
- """Post init configuration."""
+ def __init__(self, out_channels: int, hidden_dim: int, channels_multipliers: Sequence[int], dropout_rate: float, activation: str = "mish") -> None:
super().__init__()
+ self.out_channels = out_channels
+ self.hidden_dim = hidden_dim
+ self.channels_multipliers = tuple(channels_multipliers)
+ self.activation = activation
+ self.dropout_rate = dropout_rate
self.decoder = self._build_decompression_block()
def _build_decompression_block(self,) -> nn.Sequential:
+ in_channels = self.hidden_dim * self.channels_multipliers[0]
+ decoder = []
+ for _ in range(2):
+ decoder += [
+ Residual(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ dropout_rate=self.dropout_rate,
+ use_norm=True,
+ ),
+ ]
+
activation_fn = activation_function(self.activation)
- blocks = [
+ out_channels_multipliers = self.channels_multipliers + (1, )
+ num_blocks = len(self.channels_multipliers)
+
+ for i in range(num_blocks):
+ in_channels = self.hidden_dim * self.channels_multipliers[i]
+ out_channels = self.hidden_dim * out_channels_multipliers[i + 1]
+ decoder += [
+ nn.ConvTranspose2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=4,
+ stride=2,
+ padding=1,
+ ),
+ activation_fn,
+ ]
+
+ decoder += [
+ Normalize(num_channels=self.hidden_dim * out_channels_multipliers[-1]),
+ nn.Mish(),
nn.Conv2d(
- in_channels=self.in_channels,
- out_channels=self.embedding_dim,
- kernel_size=3,
- padding=1,
- )
- ]
- for _ in range(self.num_residual_layers):
- blocks.append(
- Residual(in_channels=self.embedding_dim, out_channels=self.res_channels)
- )
- blocks.append(activation_fn)
- blocks += [
- nn.ConvTranspose2d(
- in_channels=self.embedding_dim,
- out_channels=self.embedding_dim // 2,
- kernel_size=4,
- stride=2,
- padding=1,
- ),
- activation_fn,
- nn.ConvTranspose2d(
- in_channels=self.embedding_dim // 2,
+ in_channels=self.hidden_dim * out_channels_multipliers[-1],
out_channels=self.out_channels,
- kernel_size=4,
- stride=2,
+ kernel_size=3,
+ stride=1,
padding=1,
),
]
- return nn.Sequential(*blocks)
+ return nn.Sequential(*decoder)
def forward(self, z_q: Tensor) -> Tensor:
"""Reconstruct input from given codes."""