diff options
Diffstat (limited to 'text_recognizer/networks/vqvae/pixelcnn.py')
-rw-r--r-- | text_recognizer/networks/vqvae/pixelcnn.py | 12 |
1 files changed, 4 insertions, 8 deletions
diff --git a/text_recognizer/networks/vqvae/pixelcnn.py b/text_recognizer/networks/vqvae/pixelcnn.py index 5c580df..b9e6080 100644 --- a/text_recognizer/networks/vqvae/pixelcnn.py +++ b/text_recognizer/networks/vqvae/pixelcnn.py @@ -44,7 +44,7 @@ class Encoder(nn.Module): ), ] num_blocks = len(self.channels_multipliers) - in_channels_multipliers = (1,) + self.channels_multipliers + in_channels_multipliers = (1,) + self.channels_multipliers for i in range(num_blocks): in_channels = self.hidden_dim * in_channels_multipliers[i] out_channels = self.hidden_dim * self.channels_multipliers[i] @@ -68,7 +68,7 @@ class Encoder(nn.Module): dropout_rate=self.dropout_rate, use_norm=True, ), - Attention(in_channels=self.hidden_dim * self.channels_multipliers[-1]) + Attention(in_channels=self.hidden_dim * self.channels_multipliers[-1]), ] encoder += [ @@ -125,7 +125,7 @@ class Decoder(nn.Module): ), ] - out_channels_multipliers = self.channels_multipliers + (1, ) + out_channels_multipliers = self.channels_multipliers + (1,) num_blocks = len(self.channels_multipliers) for i in range(num_blocks): @@ -140,11 +140,7 @@ class Decoder(nn.Module): ) ) if i == 0: - decoder.append( - Attention( - in_channels=out_channels - ) - ) + decoder.append(Attention(in_channels=out_channels)) decoder.append(Upsample()) decoder += [ |