diff options
Diffstat (limited to 'text_recognizer/networks/vqvae')
-rw-r--r-- | text_recognizer/networks/vqvae/decoder.py | 28 | ||||
-rw-r--r-- | text_recognizer/networks/vqvae/encoder.py | 27 |
2 files changed, 39 insertions, 16 deletions
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py index 63eac13..7734a5a 100644 --- a/text_recognizer/networks/vqvae/decoder.py +++ b/text_recognizer/networks/vqvae/decoder.py @@ -20,6 +20,8 @@ class Decoder(nn.Module): dropout_rate: float, activation: str = "mish", use_norm: bool = False, + num_residuals: int = 4, + residual_channels: int = 32, ) -> None: super().__init__() self.out_channels = out_channels @@ -28,18 +30,20 @@ class Decoder(nn.Module): self.activation = activation self.dropout_rate = dropout_rate self.use_norm = use_norm + self.num_residuals = num_residuals + self.residual_channels = residual_channels self.decoder = self._build_decompression_block() def _build_decompression_block(self,) -> nn.Sequential: - in_channels = self.hidden_dim * self.channels_multipliers[0] decoder = [] - for _ in range(4): + in_channels = self.hidden_dim * self.channels_multipliers[0] + for _ in range(self.num_residuals): decoder += [ Residual( in_channels=in_channels, - out_channels=in_channels, - dropout_rate=self.dropout_rate, + residual_channels=self.residual_channels, use_norm=self.use_norm, + activation=self.activation, ), ] @@ -50,7 +54,12 @@ class Decoder(nn.Module): for i in range(num_blocks): in_channels = self.hidden_dim * self.channels_multipliers[i] out_channels = self.hidden_dim * out_channels_multipliers[i + 1] + if self.use_norm: + decoder += [ + Normalize(num_channels=in_channels,), + ] decoder += [ + activation_fn, nn.ConvTranspose2d( in_channels=in_channels, out_channels=out_channels, @@ -58,12 +67,17 @@ class Decoder(nn.Module): stride=2, padding=1, ), - activation_fn, + ] + + if self.use_norm: + decoder += [ + Normalize( + num_channels=self.hidden_dim * out_channels_multipliers[-1], + num_groups=self.hidden_dim * out_channels_multipliers[-1] // 4, + ), ] decoder += [ - Normalize(num_channels=self.hidden_dim * out_channels_multipliers[-1]), - activation_fn, nn.Conv2d( in_channels=self.hidden_dim * out_channels_multipliers[-1], out_channels=self.out_channels, diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py index b8179f0..4761486 100644 --- a/text_recognizer/networks/vqvae/encoder.py +++ b/text_recognizer/networks/vqvae/encoder.py @@ -5,6 +5,7 @@ from torch import nn from torch import Tensor from text_recognizer.networks.util import activation_function +from text_recognizer.networks.vqvae.norm import Normalize from text_recognizer.networks.vqvae.residual import Residual @@ -19,6 +20,8 @@ class Encoder(nn.Module): dropout_rate: float, activation: str = "mish", use_norm: bool = False, + num_residuals: int = 4, + residual_channels: int = 32, ) -> None: super().__init__() self.in_channels = in_channels @@ -27,10 +30,16 @@ class Encoder(nn.Module): self.activation = activation self.dropout_rate = dropout_rate self.use_norm = use_norm + self.num_residuals = num_residuals + self.residual_channels = residual_channels self.encoder = self._build_compression_block() def _build_compression_block(self) -> nn.Sequential: """Builds encoder network.""" + num_blocks = len(self.channels_multipliers) + channels_multipliers = (1,) + self.channels_multipliers + activation_fn = activation_function(self.activation) + encoder = [ nn.Conv2d( in_channels=self.in_channels, @@ -41,14 +50,15 @@ class Encoder(nn.Module): ), ] - num_blocks = len(self.channels_multipliers) - channels_multipliers = (1,) + self.channels_multipliers - activation_fn = activation_function(self.activation) - for i in range(num_blocks): in_channels = self.hidden_dim * channels_multipliers[i] out_channels = self.hidden_dim * channels_multipliers[i + 1] + if self.use_norm: + encoder += [ + Normalize(num_channels=in_channels,), + ] encoder += [ + activation_fn, nn.Conv2d( in_channels=in_channels, out_channels=out_channels, @@ -56,16 +66,15 @@ class Encoder(nn.Module): stride=2, padding=1, ), - activation_fn, ] - for _ in range(4): + for _ in range(self.num_residuals): encoder += [ Residual( - in_channels=self.hidden_dim * self.channels_multipliers[-1], - out_channels=self.hidden_dim * self.channels_multipliers[-1], - dropout_rate=self.dropout_rate, + in_channels=out_channels, + residual_channels=self.residual_channels, use_norm=self.use_norm, + activation=self.activation, ) ] |