summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/vqvae/decoder.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-05 23:24:20 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-05 23:24:20 +0200
commitdedf8deb025ac9efdad5e9baf9165ef63d6829ff (patch)
tree56b10fcaef479d8abe9b0e6c05e07ad5e02b9ab0 /text_recognizer/networks/vqvae/decoder.py
parent532286b516b17d279c321358bf03dddc8adc8029 (diff)
Pre-commit fixes, optimizer loading fix
Diffstat (limited to 'text_recognizer/networks/vqvae/decoder.py')
-rw-r--r--text_recognizer/networks/vqvae/decoder.py18
1 files changed, 15 insertions, 3 deletions
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
index 8847aba..67ed0d9 100644
--- a/text_recognizer/networks/vqvae/decoder.py
+++ b/text_recognizer/networks/vqvae/decoder.py
@@ -44,7 +44,12 @@ class Decoder(nn.Module):
# Configure encoder.
self.decoder = self._build_decoder(
- channels, kernel_sizes, strides, num_residual_layers, activation, dropout,
+ channels,
+ kernel_sizes,
+ strides,
+ num_residual_layers,
+ activation,
+ dropout,
)
def _build_decompression_block(
@@ -73,7 +78,9 @@ class Decoder(nn.Module):
)
if i < len(self.upsampling):
- modules.append(nn.Upsample(size=self.upsampling[i]),)
+ modules.append(
+ nn.Upsample(size=self.upsampling[i]),
+ )
if dropout is not None:
modules.append(dropout)
@@ -102,7 +109,12 @@ class Decoder(nn.Module):
) -> nn.Sequential:
self.res_block.append(
- nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,)
+ nn.Conv2d(
+ self.embedding_dim,
+ channels[0],
+ kernel_size=1,
+ stride=1,
+ )
)
# Bottleneck module.