diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-08 19:59:55 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-08 19:59:55 +0200 |
commit | 240f5e9f20032e82515fa66ce784619527d1041e (patch) | |
tree | b002d28bbfc9abe9b6af090f7db60bea0aeed6e8 /text_recognizer/networks/vqvae/decoder.py | |
parent | d12f70402371dda586d457af2a3df7fb5b3130ad (diff) |
Add VQGAN and loss function
Diffstat (limited to 'text_recognizer/networks/vqvae/decoder.py')
-rw-r--r-- | text_recognizer/networks/vqvae/decoder.py | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py index f51e0a3..fcbed57 100644 --- a/text_recognizer/networks/vqvae/decoder.py +++ b/text_recognizer/networks/vqvae/decoder.py @@ -12,7 +12,14 @@ from text_recognizer.networks.vqvae.residual import Residual class Decoder(nn.Module): """A CNN encoder network.""" - def __init__(self, out_channels: int, hidden_dim: int, channels_multipliers: Sequence[int], dropout_rate: float, activation: str = "mish") -> None: + def __init__( + self, + out_channels: int, + hidden_dim: int, + channels_multipliers: Sequence[int], + dropout_rate: float, + activation: str = "mish", + ) -> None: super().__init__() self.out_channels = out_channels self.hidden_dim = hidden_dim @@ -33,9 +40,9 @@ class Decoder(nn.Module): use_norm=True, ), ] - + activation_fn = activation_function(self.activation) - out_channels_multipliers = self.channels_multipliers + (1, ) + out_channels_multipliers = self.channels_multipliers + (1,) num_blocks = len(self.channels_multipliers) for i in range(num_blocks): |