summary | refs | log | tree | commit | diff
path: root/text_recognizer/networks/vqvae
diff options
context:
space:
mode:
author: Gustaf Rydholm <gustaf.rydholm@gmail.com>, 2021-09-19 21:05:30 +0200
committer: Gustaf Rydholm <gustaf.rydholm@gmail.com>, 2021-09-19 21:05:30 +0200
commit: 6dd33e5a087159dcbabb845f167279778b2a8ea5 (patch)
tree: 72e187988779760ceac8fab0bc3b1142964fb936 /text_recognizer/networks/vqvae
parent: 99886b4a9664b0716319e54f361091e2bfdf4b8f (diff)
Add ability to set use_norm in the VQ-VAE encoder and decoder
Diffstat (limited to 'text_recognizer/networks/vqvae')
-rw-r--r-- text_recognizer/networks/vqvae/decoder.py | 6
-rw-r--r-- text_recognizer/networks/vqvae/encoder.py | 4
2 files changed, 7 insertions, 3 deletions
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
index 5279cbd..63eac13 100644
--- a/text_recognizer/networks/vqvae/decoder.py
+++ b/text_recognizer/networks/vqvae/decoder.py
@@ -19,6 +19,7 @@ class Decoder(nn.Module):
channels_multipliers: Sequence[int],
dropout_rate: float,
activation: str = "mish",
+ use_norm: bool = False,
) -> None:
super().__init__()
self.out_channels = out_channels
@@ -26,6 +27,7 @@ class Decoder(nn.Module):
self.channels_multipliers = tuple(channels_multipliers)
self.activation = activation
self.dropout_rate = dropout_rate
+ self.use_norm = use_norm
self.decoder = self._build_decompression_block()
def _build_decompression_block(self,) -> nn.Sequential:
@@ -37,7 +39,7 @@ class Decoder(nn.Module):
in_channels=in_channels,
out_channels=in_channels,
dropout_rate=self.dropout_rate,
- use_norm=False,
+ use_norm=self.use_norm,
),
]
@@ -61,7 +63,7 @@ class Decoder(nn.Module):
decoder += [
Normalize(num_channels=self.hidden_dim * out_channels_multipliers[-1]),
- nn.Mish(),
+ activation_fn,
nn.Conv2d(
in_channels=self.hidden_dim * out_channels_multipliers[-1],
out_channels=self.out_channels,
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
index fe5ef4b..b8179f0 100644
--- a/text_recognizer/networks/vqvae/encoder.py
+++ b/text_recognizer/networks/vqvae/encoder.py
@@ -18,6 +18,7 @@ class Encoder(nn.Module):
channels_multipliers: List[int],
dropout_rate: float,
activation: str = "mish",
+ use_norm: bool = False,
) -> None:
super().__init__()
self.in_channels = in_channels
@@ -25,6 +26,7 @@ class Encoder(nn.Module):
self.channels_multipliers = tuple(channels_multipliers)
self.activation = activation
self.dropout_rate = dropout_rate
+ self.use_norm = use_norm
self.encoder = self._build_compression_block()
def _build_compression_block(self) -> nn.Sequential:
@@ -63,7 +65,7 @@ class Encoder(nn.Module):
in_channels=self.hidden_dim * self.channels_multipliers[-1],
out_channels=self.hidden_dim * self.channels_multipliers[-1],
dropout_rate=self.dropout_rate,
- use_norm=False,
+ use_norm=self.use_norm,
)
]