diff options
Diffstat (limited to 'text_recognizer/networks/vqvae/decoder.py')
-rw-r--r-- | text_recognizer/networks/vqvae/decoder.py | 18 |
1 files changed, 3 insertions, 15 deletions
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py index 93a1e43..32de912 100644 --- a/text_recognizer/networks/vqvae/decoder.py +++ b/text_recognizer/networks/vqvae/decoder.py @@ -44,12 +44,7 @@ class Decoder(nn.Module): # Configure encoder. self.decoder = self._build_decoder( - channels, - kernel_sizes, - strides, - num_residual_layers, - activation, - dropout, + channels, kernel_sizes, strides, num_residual_layers, activation, dropout, ) def _build_decompression_block( @@ -78,9 +73,7 @@ class Decoder(nn.Module): ) if self.upsampling and i < len(self.upsampling): - modules.append( - nn.Upsample(size=self.upsampling[i]), - ) + modules.append(nn.Upsample(size=self.upsampling[i]),) if dropout is not None: modules.append(dropout) @@ -109,12 +102,7 @@ class Decoder(nn.Module): ) -> nn.Sequential: self.res_block.append( - nn.Conv2d( - self.embedding_dim, - channels[0], - kernel_size=1, - stride=1, - ) + nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,) ) # Bottleneck module. |