diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-09-19 21:05:30 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-09-19 21:05:30 +0200 |
commit | 6dd33e5a087159dcbabb845f167279778b2a8ea5 (patch) | |
tree | 72e187988779760ceac8fab0bc3b1142964fb936 /text_recognizer/networks | |
parent | 99886b4a9664b0716319e54f361091e2bfdf4b8f (diff) |
Add ability to set use norm in vqvae
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/vqvae/decoder.py | 6 | ||||
-rw-r--r-- | text_recognizer/networks/vqvae/encoder.py | 4 |
2 files changed, 7 insertions, 3 deletions
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py index 5279cbd..63eac13 100644 --- a/text_recognizer/networks/vqvae/decoder.py +++ b/text_recognizer/networks/vqvae/decoder.py @@ -19,6 +19,7 @@ class Decoder(nn.Module): channels_multipliers: Sequence[int], dropout_rate: float, activation: str = "mish", + use_norm: bool = False, ) -> None: super().__init__() self.out_channels = out_channels @@ -26,6 +27,7 @@ class Decoder(nn.Module): self.channels_multipliers = tuple(channels_multipliers) self.activation = activation self.dropout_rate = dropout_rate + self.use_norm = use_norm self.decoder = self._build_decompression_block() def _build_decompression_block(self,) -> nn.Sequential: @@ -37,7 +39,7 @@ class Decoder(nn.Module): in_channels=in_channels, out_channels=in_channels, dropout_rate=self.dropout_rate, - use_norm=False, + use_norm=self.use_norm, ), ] @@ -61,7 +63,7 @@ class Decoder(nn.Module): decoder += [ Normalize(num_channels=self.hidden_dim * out_channels_multipliers[-1]), - nn.Mish(), + activation_fn, nn.Conv2d( in_channels=self.hidden_dim * out_channels_multipliers[-1], out_channels=self.out_channels, diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py index fe5ef4b..b8179f0 100644 --- a/text_recognizer/networks/vqvae/encoder.py +++ b/text_recognizer/networks/vqvae/encoder.py @@ -18,6 +18,7 @@ class Encoder(nn.Module): channels_multipliers: List[int], dropout_rate: float, activation: str = "mish", + use_norm: bool = False, ) -> None: super().__init__() self.in_channels = in_channels @@ -25,6 +26,7 @@ class Encoder(nn.Module): self.channels_multipliers = tuple(channels_multipliers) self.activation = activation self.dropout_rate = dropout_rate + self.use_norm = use_norm self.encoder = self._build_compression_block() def _build_compression_block(self) -> nn.Sequential: @@ -63,7 +65,7 @@ class Encoder(nn.Module): in_channels=self.hidden_dim * self.channels_multipliers[-1], out_channels=self.hidden_dim * self.channels_multipliers[-1], dropout_rate=self.dropout_rate, - use_norm=False, + use_norm=self.use_norm, ) ] |