summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/vqvae/decoder.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/vqvae/decoder.py')
-rw-r--r--text_recognizer/networks/vqvae/decoder.py13
1 files changed, 10 insertions, 3 deletions
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
index f51e0a3..fcbed57 100644
--- a/text_recognizer/networks/vqvae/decoder.py
+++ b/text_recognizer/networks/vqvae/decoder.py
@@ -12,7 +12,14 @@ from text_recognizer.networks.vqvae.residual import Residual
class Decoder(nn.Module):
"""A CNN encoder network."""
- def __init__(self, out_channels: int, hidden_dim: int, channels_multipliers: Sequence[int], dropout_rate: float, activation: str = "mish") -> None:
+ def __init__(
+ self,
+ out_channels: int,
+ hidden_dim: int,
+ channels_multipliers: Sequence[int],
+ dropout_rate: float,
+ activation: str = "mish",
+ ) -> None:
super().__init__()
self.out_channels = out_channels
self.hidden_dim = hidden_dim
@@ -33,9 +40,9 @@ class Decoder(nn.Module):
use_norm=True,
),
]
-
+
activation_fn = activation_function(self.activation)
- out_channels_multipliers = self.channels_multipliers + (1, )
+ out_channels_multipliers = self.channels_multipliers + (1,)
num_blocks = len(self.channels_multipliers)
for i in range(num_blocks):