summary | refs | log | tree | commit | diff
path: root/text_recognizer/networks/vqvae/decoder.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/vqvae/decoder.py')
-rw-r--r-- text_recognizer/networks/vqvae/decoder.py 28
1 files changed, 21 insertions, 7 deletions
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
index 63eac13..7734a5a 100644
--- a/text_recognizer/networks/vqvae/decoder.py
+++ b/text_recognizer/networks/vqvae/decoder.py
@@ -20,6 +20,8 @@ class Decoder(nn.Module):
dropout_rate: float,
activation: str = "mish",
use_norm: bool = False,
+ num_residuals: int = 4,
+ residual_channels: int = 32,
) -> None:
super().__init__()
self.out_channels = out_channels
@@ -28,18 +30,20 @@ class Decoder(nn.Module):
self.activation = activation
self.dropout_rate = dropout_rate
self.use_norm = use_norm
+ self.num_residuals = num_residuals
+ self.residual_channels = residual_channels
self.decoder = self._build_decompression_block()
def _build_decompression_block(self,) -> nn.Sequential:
- in_channels = self.hidden_dim * self.channels_multipliers[0]
decoder = []
- for _ in range(4):
+ in_channels = self.hidden_dim * self.channels_multipliers[0]
+ for _ in range(self.num_residuals):
decoder += [
Residual(
in_channels=in_channels,
- out_channels=in_channels,
- dropout_rate=self.dropout_rate,
+ residual_channels=self.residual_channels,
use_norm=self.use_norm,
+ activation=self.activation,
),
]
@@ -50,7 +54,12 @@ class Decoder(nn.Module):
for i in range(num_blocks):
in_channels = self.hidden_dim * self.channels_multipliers[i]
out_channels = self.hidden_dim * out_channels_multipliers[i + 1]
+ if self.use_norm:
+ decoder += [
+ Normalize(num_channels=in_channels,),
+ ]
decoder += [
+ activation_fn,
nn.ConvTranspose2d(
in_channels=in_channels,
out_channels=out_channels,
@@ -58,12 +67,17 @@ class Decoder(nn.Module):
stride=2,
padding=1,
),
- activation_fn,
+ ]
+
+ if self.use_norm:
+ decoder += [
+ Normalize(
+ num_channels=self.hidden_dim * out_channels_multipliers[-1],
+ num_groups=self.hidden_dim * out_channels_multipliers[-1] // 4,
+ ),
]
decoder += [
- Normalize(num_channels=self.hidden_dim * out_channels_multipliers[-1]),
- activation_fn,
nn.Conv2d(
in_channels=self.hidden_dim * out_channels_multipliers[-1],
out_channels=self.out_channels,