diff options
| -rw-r--r-- | text_recognizer/networks/vqvae/decoder.py | 28 |
| -rw-r--r-- | text_recognizer/networks/vqvae/encoder.py | 27 | 
2 files changed, 39 insertions, 16 deletions
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py index 63eac13..7734a5a 100644 --- a/text_recognizer/networks/vqvae/decoder.py +++ b/text_recognizer/networks/vqvae/decoder.py @@ -20,6 +20,8 @@ class Decoder(nn.Module):          dropout_rate: float,          activation: str = "mish",          use_norm: bool = False, +        num_residuals: int = 4, +        residual_channels: int = 32,      ) -> None:          super().__init__()          self.out_channels = out_channels @@ -28,18 +30,20 @@ class Decoder(nn.Module):          self.activation = activation          self.dropout_rate = dropout_rate          self.use_norm = use_norm +        self.num_residuals = num_residuals +        self.residual_channels = residual_channels          self.decoder = self._build_decompression_block()      def _build_decompression_block(self,) -> nn.Sequential: -        in_channels = self.hidden_dim * self.channels_multipliers[0]          decoder = [] -        for _ in range(4): +        in_channels = self.hidden_dim * self.channels_multipliers[0] +        for _ in range(self.num_residuals):              decoder += [                  Residual(                      in_channels=in_channels, -                    out_channels=in_channels, -                    dropout_rate=self.dropout_rate, +                    residual_channels=self.residual_channels,                      use_norm=self.use_norm, +                    activation=self.activation,                  ),              ] @@ -50,7 +54,12 @@ class Decoder(nn.Module):          for i in range(num_blocks):              in_channels = self.hidden_dim * self.channels_multipliers[i]              out_channels = self.hidden_dim * out_channels_multipliers[i + 1] +            if self.use_norm: +                decoder += [ +                    Normalize(num_channels=in_channels,), +                ]              decoder += [ +                activation_fn,                  nn.ConvTranspose2d(        
              in_channels=in_channels,                      out_channels=out_channels, @@ -58,12 +67,17 @@ class Decoder(nn.Module):                      stride=2,                      padding=1,                  ), -                activation_fn, +            ] + +        if self.use_norm: +            decoder += [ +                Normalize( +                    num_channels=self.hidden_dim * out_channels_multipliers[-1], +                    num_groups=self.hidden_dim * out_channels_multipliers[-1] // 4, +                ),              ]          decoder += [ -            Normalize(num_channels=self.hidden_dim * out_channels_multipliers[-1]), -            activation_fn,              nn.Conv2d(                  in_channels=self.hidden_dim * out_channels_multipliers[-1],                  out_channels=self.out_channels, diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py index b8179f0..4761486 100644 --- a/text_recognizer/networks/vqvae/encoder.py +++ b/text_recognizer/networks/vqvae/encoder.py @@ -5,6 +5,7 @@ from torch import nn  from torch import Tensor  from text_recognizer.networks.util import activation_function +from text_recognizer.networks.vqvae.norm import Normalize  from text_recognizer.networks.vqvae.residual import Residual @@ -19,6 +20,8 @@ class Encoder(nn.Module):          dropout_rate: float,          activation: str = "mish",          use_norm: bool = False, +        num_residuals: int = 4, +        residual_channels: int = 32,      ) -> None:          super().__init__()          self.in_channels = in_channels @@ -27,10 +30,16 @@ class Encoder(nn.Module):          self.activation = activation          self.dropout_rate = dropout_rate          self.use_norm = use_norm +        self.num_residuals = num_residuals +        self.residual_channels = residual_channels          self.encoder = self._build_compression_block()      def _build_compression_block(self) -> nn.Sequential:          """Builds encoder 
network.""" +        num_blocks = len(self.channels_multipliers) +        channels_multipliers = (1,) + self.channels_multipliers +        activation_fn = activation_function(self.activation) +          encoder = [              nn.Conv2d(                  in_channels=self.in_channels, @@ -41,14 +50,15 @@ class Encoder(nn.Module):              ),          ] -        num_blocks = len(self.channels_multipliers) -        channels_multipliers = (1,) + self.channels_multipliers -        activation_fn = activation_function(self.activation) -          for i in range(num_blocks):              in_channels = self.hidden_dim * channels_multipliers[i]              out_channels = self.hidden_dim * channels_multipliers[i + 1] +            if self.use_norm: +                encoder += [ +                    Normalize(num_channels=in_channels,), +                ]              encoder += [ +                activation_fn,                  nn.Conv2d(                      in_channels=in_channels,                      out_channels=out_channels, @@ -56,16 +66,15 @@ class Encoder(nn.Module):                      stride=2,                      padding=1,                  ), -                activation_fn,              ] -        for _ in range(4): +        for _ in range(self.num_residuals):              encoder += [                  Residual( -                    in_channels=self.hidden_dim * self.channels_multipliers[-1], -                    out_channels=self.hidden_dim * self.channels_multipliers[-1], -                    dropout_rate=self.dropout_rate, +                    in_channels=out_channels, +                    residual_channels=self.residual_channels,                      use_norm=self.use_norm, +                    activation=self.activation,                  )              ]  |