diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-06 02:42:45 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-06 02:42:45 +0200 |
commit | 3ab82ad36bce6fa698a13a029a0694b75a5947b7 (patch) | |
tree | 136f71a62d60e3ccf01e1f95d64bb4d9f9c9befe /text_recognizer/networks/vqvae/encoder.py | |
parent | 1bccf71cf4eec335001b50a8fbc0c991d0e6d13a (diff) |
Fix VQVAE into en/decoder, bug in wandb artifact code uploading
Diffstat (limited to 'text_recognizer/networks/vqvae/encoder.py')
-rw-r--r-- | text_recognizer/networks/vqvae/encoder.py | 82 |
1 files changed, 38 insertions, 44 deletions
```diff
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
index f086c6b..ad8f950 100644
--- a/text_recognizer/networks/vqvae/encoder.py
+++ b/text_recognizer/networks/vqvae/encoder.py
@@ -1,7 +1,6 @@
 """CNN encoder for the VQ-VAE."""
-from typing import Sequence, Optional, Tuple, Type
+from typing import List, Tuple
 
-import attr
 from torch import nn
 from torch import Tensor
 
@@ -9,64 +8,59 @@ from text_recognizer.networks.util import activation_function
 from text_recognizer.networks.vqvae.residual import Residual
 
 
-@attr.s(eq=False)
 class Encoder(nn.Module):
     """A CNN encoder network."""
 
-    in_channels: int = attr.ib()
-    out_channels: int = attr.ib()
-    res_channels: int = attr.ib()
-    num_residual_layers: int = attr.ib()
-    embedding_dim: int = attr.ib()
-    activation: str = attr.ib()
-    encoder: nn.Sequential = attr.ib(init=False)
-
-    def __attrs_post_init__(self) -> None:
-        """Post init configuration."""
+    def __init__(self, in_channels: int, hidden_dim: int, channels_multipliers: List[int], dropout_rate: float, activation: str = "mish") -> None:
         super().__init__()
+        self.in_channels = in_channels
+        self.hidden_dim = hidden_dim
+        self.channels_multipliers = tuple(channels_multipliers)
+        self.activation = activation
+        self.dropout_rate = dropout_rate
         self.encoder = self._build_compression_block()
 
     def _build_compression_block(self) -> nn.Sequential:
-        activation_fn = activation_function(self.activation)
-        block = [
+        """Builds encoder network."""
+        encoder = [
             nn.Conv2d(
                 in_channels=self.in_channels,
-                out_channels=self.out_channels // 2,
-                kernel_size=4,
-                stride=2,
-                padding=1,
-            ),
-            activation_fn,
-            nn.Conv2d(
-                in_channels=self.out_channels // 2,
-                out_channels=self.out_channels,
-                kernel_size=4,
-                stride=2,
-                padding=1,
-            ),
-            activation_fn,
-            nn.Conv2d(
-                in_channels=self.out_channels,
-                out_channels=self.out_channels,
+                out_channels=self.hidden_dim,
                 kernel_size=3,
+                stride=1,
                 padding=1,
             ),
         ]
 
-        for _ in range(self.num_residual_layers):
-            block.append(
-                Residual(in_channels=self.out_channels, out_channels=self.res_channels)
-            )
+        num_blocks = len(self.channels_multipliers)
+        channels_multipliers = (1, ) + self.channels_multipliers
+        activation_fn = activation_function(self.activation)
 
-        block.append(
-            nn.Conv2d(
-                in_channels=self.out_channels,
-                out_channels=self.embedding_dim,
-                kernel_size=1,
-            )
-        )
+        for i in range(num_blocks):
+            in_channels = self.hidden_dim * channels_multipliers[i]
+            out_channels = self.hidden_dim * channels_multipliers[i + 1]
+            encoder += [
+                nn.Conv2d(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    kernel_size=4,
+                    stride=2,
+                    padding=1,
+                ),
+                activation_fn,
+            ]
+
+        for _ in range(2):
+            encoder += [
+                Residual(
+                    in_channels=self.hidden_dim * self.channels_multipliers[-1],
+                    out_channels=self.hidden_dim * self.channels_multipliers[-1],
+                    dropout_rate=self.dropout_rate,
+                    use_norm=True,
+                )
+            ]
 
-        return nn.Sequential(*block)
+        return nn.Sequential(*encoder)
 
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
         """Encodes input into a discrete representation."""
```