summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/vqvae/decoder.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/vqvae/decoder.py')
-rw-r--r--text_recognizer/networks/vqvae/decoder.py18
1 files changed, 3 insertions, 15 deletions
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
index 93a1e43..32de912 100644
--- a/text_recognizer/networks/vqvae/decoder.py
+++ b/text_recognizer/networks/vqvae/decoder.py
@@ -44,12 +44,7 @@ class Decoder(nn.Module):
# Configure encoder.
self.decoder = self._build_decoder(
- channels,
- kernel_sizes,
- strides,
- num_residual_layers,
- activation,
- dropout,
+ channels, kernel_sizes, strides, num_residual_layers, activation, dropout,
)
def _build_decompression_block(
@@ -78,9 +73,7 @@ class Decoder(nn.Module):
)
if self.upsampling and i < len(self.upsampling):
- modules.append(
- nn.Upsample(size=self.upsampling[i]),
- )
+ modules.append(nn.Upsample(size=self.upsampling[i]),)
if dropout is not None:
modules.append(dropout)
@@ -109,12 +102,7 @@ class Decoder(nn.Module):
) -> nn.Sequential:
self.res_block.append(
- nn.Conv2d(
- self.embedding_dim,
- channels[0],
- kernel_size=1,
- stride=1,
- )
+ nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,)
)
# Bottleneck module.