From 9426cc794d8c28a65bbbf5ae5466a0a343078558 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Sun, 25 Apr 2021 23:32:50 +0200
Subject: Efficient net and non working transformer model.

---
 text_recognizer/networks/vqvae/decoder.py | 18 +++---------------
 text_recognizer/networks/vqvae/encoder.py | 12 ++----------
 2 files changed, 5 insertions(+), 25 deletions(-)

(limited to 'text_recognizer/networks/vqvae')

diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
index 93a1e43..32de912 100644
--- a/text_recognizer/networks/vqvae/decoder.py
+++ b/text_recognizer/networks/vqvae/decoder.py
@@ -44,12 +44,7 @@ class Decoder(nn.Module):
 
         # Configure encoder.
         self.decoder = self._build_decoder(
-            channels,
-            kernel_sizes,
-            strides,
-            num_residual_layers,
-            activation,
-            dropout,
+            channels, kernel_sizes, strides, num_residual_layers, activation, dropout,
         )
 
     def _build_decompression_block(
@@ -78,9 +73,7 @@ class Decoder(nn.Module):
             )
 
             if self.upsampling and i < len(self.upsampling):
-                modules.append(
-                    nn.Upsample(size=self.upsampling[i]),
-                )
+                modules.append(nn.Upsample(size=self.upsampling[i]),)
 
             if dropout is not None:
                 modules.append(dropout)
@@ -109,12 +102,7 @@ class Decoder(nn.Module):
     ) -> nn.Sequential:
 
         self.res_block.append(
-            nn.Conv2d(
-                self.embedding_dim,
-                channels[0],
-                kernel_size=1,
-                stride=1,
-            )
+            nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,)
         )
 
         # Bottleneck module.
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
index b0cceed..65801df 100644
--- a/text_recognizer/networks/vqvae/encoder.py
+++ b/text_recognizer/networks/vqvae/encoder.py
@@ -11,10 +11,7 @@ from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer
 
 class _ResidualBlock(nn.Module):
     def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        dropout: Optional[Type[nn.Module]],
+        self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]],
     ) -> None:
         super().__init__()
         self.block = [
@@ -138,12 +135,7 @@ class Encoder(nn.Module):
         )
 
         encoder.append(
-            nn.Conv2d(
-                channels[-1],
-                self.embedding_dim,
-                kernel_size=1,
-                stride=1,
-            )
+            nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,)
         )
 
         return nn.Sequential(*encoder)
-- 
cgit v1.2.3-70-g09d2