summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/vqvae/encoder.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-07 22:12:10 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-07 22:12:10 +0200
commit8afa8e1c6e9623b0dea86236da04b2b4173e9443 (patch)
tree4c9462507b3b3076aa26f08ab629f64b90aed2cb /text_recognizer/networks/vqvae/encoder.py
parent33190bc9c0c377edab280efe4b0bd0e53bb6cb00 (diff)
Fixed typing and typos, train script load config, reformatted
Diffstat (limited to 'text_recognizer/networks/vqvae/encoder.py')
-rw-r--r--text_recognizer/networks/vqvae/encoder.py12
1 files changed, 2 insertions, 10 deletions
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
index ede5c31..d3adac5 100644
--- a/text_recognizer/networks/vqvae/encoder.py
+++ b/text_recognizer/networks/vqvae/encoder.py
@@ -11,10 +11,7 @@ from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer
class _ResidualBlock(nn.Module):
def __init__(
- self,
- in_channels: int,
- out_channels: int,
- dropout: Optional[Type[nn.Module]],
+ self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]],
) -> None:
super().__init__()
self.block = [
@@ -138,12 +135,7 @@ class Encoder(nn.Module):
)
encoder.append(
- nn.Conv2d(
- channels[-1],
- self.embedding_dim,
- kernel_size=1,
- stride=1,
- )
+ nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,)
)
return nn.Sequential(*encoder)