diff options
Diffstat (limited to 'text_recognizer/networks/vqvae')
| -rw-r--r-- | text_recognizer/networks/vqvae/decoder.py | 6 | ++++-- |
| -rw-r--r-- | text_recognizer/networks/vqvae/encoder.py | 4 | +++- |
2 files changed, 7 insertions, 3 deletions
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
index 5279cbd..63eac13 100644
--- a/text_recognizer/networks/vqvae/decoder.py
+++ b/text_recognizer/networks/vqvae/decoder.py
@@ -19,6 +19,7 @@ class Decoder(nn.Module):
         channels_multipliers: Sequence[int],
         dropout_rate: float,
         activation: str = "mish",
+        use_norm: bool = False,
     ) -> None:
         super().__init__()
         self.out_channels = out_channels
@@ -26,6 +27,7 @@ class Decoder(nn.Module):
         self.channels_multipliers = tuple(channels_multipliers)
         self.activation = activation
         self.dropout_rate = dropout_rate
+        self.use_norm = use_norm
         self.decoder = self._build_decompression_block()
 
     def _build_decompression_block(self,) -> nn.Sequential:
@@ -37,7 +39,7 @@ class Decoder(nn.Module):
                     in_channels=in_channels,
                     out_channels=in_channels,
                     dropout_rate=self.dropout_rate,
-                    use_norm=False,
+                    use_norm=self.use_norm,
                 ),
             ]
 
@@ -61,7 +63,7 @@ class Decoder(nn.Module):
 
         decoder += [
             Normalize(num_channels=self.hidden_dim * out_channels_multipliers[-1]),
-            nn.Mish(),
+            activation_fn,
             nn.Conv2d(
                 in_channels=self.hidden_dim * out_channels_multipliers[-1],
                 out_channels=self.out_channels,
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
index fe5ef4b..b8179f0 100644
--- a/text_recognizer/networks/vqvae/encoder.py
+++ b/text_recognizer/networks/vqvae/encoder.py
@@ -18,6 +18,7 @@ class Encoder(nn.Module):
         channels_multipliers: List[int],
         dropout_rate: float,
         activation: str = "mish",
+        use_norm: bool = False,
     ) -> None:
         super().__init__()
         self.in_channels = in_channels
@@ -25,6 +26,7 @@ class Encoder(nn.Module):
         self.channels_multipliers = tuple(channels_multipliers)
         self.activation = activation
         self.dropout_rate = dropout_rate
+        self.use_norm = use_norm
         self.encoder = self._build_compression_block()
 
     def _build_compression_block(self) -> nn.Sequential:
@@ -63,7 +65,7 @@ class Encoder(nn.Module):
                     in_channels=self.hidden_dim * self.channels_multipliers[-1],
                     out_channels=self.hidden_dim * self.channels_multipliers[-1],
                     dropout_rate=self.dropout_rate,
-                    use_norm=False,
+                    use_norm=self.use_norm,
                 )
             ]
 