From 240f5e9f20032e82515fa66ce784619527d1041e Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 8 Aug 2021 19:59:55 +0200 Subject: Add VQGAN and loss function --- text_recognizer/networks/vqvae/pixelcnn.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) (limited to 'text_recognizer/networks/vqvae/pixelcnn.py') diff --git a/text_recognizer/networks/vqvae/pixelcnn.py b/text_recognizer/networks/vqvae/pixelcnn.py index 5c580df..b9e6080 100644 --- a/text_recognizer/networks/vqvae/pixelcnn.py +++ b/text_recognizer/networks/vqvae/pixelcnn.py @@ -44,7 +44,7 @@ class Encoder(nn.Module): ), ] num_blocks = len(self.channels_multipliers) - in_channels_multipliers = (1,) + self.channels_multipliers + in_channels_multipliers = (1,) + self.channels_multipliers for i in range(num_blocks): in_channels = self.hidden_dim * in_channels_multipliers[i] out_channels = self.hidden_dim * self.channels_multipliers[i] @@ -68,7 +68,7 @@ class Encoder(nn.Module): dropout_rate=self.dropout_rate, use_norm=True, ), - Attention(in_channels=self.hidden_dim * self.channels_multipliers[-1]) + Attention(in_channels=self.hidden_dim * self.channels_multipliers[-1]), ] encoder += [ @@ -125,7 +125,7 @@ class Decoder(nn.Module): ), ] - out_channels_multipliers = self.channels_multipliers + (1, ) + out_channels_multipliers = self.channels_multipliers + (1,) num_blocks = len(self.channels_multipliers) for i in range(num_blocks): @@ -140,11 +140,7 @@ class Decoder(nn.Module): ) ) if i == 0: - decoder.append( - Attention( - in_channels=out_channels - ) - ) + decoder.append(Attention(in_channels=out_channels)) decoder.append(Upsample()) decoder += [ -- cgit v1.2.3-70-g09d2