summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/vqvae/pixelcnn.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-08 19:59:55 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-08 19:59:55 +0200
commit240f5e9f20032e82515fa66ce784619527d1041e (patch)
treeb002d28bbfc9abe9b6af090f7db60bea0aeed6e8 /text_recognizer/networks/vqvae/pixelcnn.py
parentd12f70402371dda586d457af2a3df7fb5b3130ad (diff)
Add VQGAN and loss function
Diffstat (limited to 'text_recognizer/networks/vqvae/pixelcnn.py')
-rw-r--r--text_recognizer/networks/vqvae/pixelcnn.py12
1 files changed, 4 insertions, 8 deletions
diff --git a/text_recognizer/networks/vqvae/pixelcnn.py b/text_recognizer/networks/vqvae/pixelcnn.py
index 5c580df..b9e6080 100644
--- a/text_recognizer/networks/vqvae/pixelcnn.py
+++ b/text_recognizer/networks/vqvae/pixelcnn.py
@@ -44,7 +44,7 @@ class Encoder(nn.Module):
),
]
num_blocks = len(self.channels_multipliers)
- in_channels_multipliers = (1,) + self.channels_multipliers
+ in_channels_multipliers = (1,) + self.channels_multipliers
for i in range(num_blocks):
in_channels = self.hidden_dim * in_channels_multipliers[i]
out_channels = self.hidden_dim * self.channels_multipliers[i]
@@ -68,7 +68,7 @@ class Encoder(nn.Module):
dropout_rate=self.dropout_rate,
use_norm=True,
),
- Attention(in_channels=self.hidden_dim * self.channels_multipliers[-1])
+ Attention(in_channels=self.hidden_dim * self.channels_multipliers[-1]),
]
encoder += [
@@ -125,7 +125,7 @@ class Decoder(nn.Module):
),
]
- out_channels_multipliers = self.channels_multipliers + (1, )
+ out_channels_multipliers = self.channels_multipliers + (1,)
num_blocks = len(self.channels_multipliers)
for i in range(num_blocks):
@@ -140,11 +140,7 @@ class Decoder(nn.Module):
)
)
if i == 0:
- decoder.append(
- Attention(
- in_channels=out_channels
- )
- )
+ decoder.append(Attention(in_channels=out_channels))
decoder.append(Upsample())
decoder += [