summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/vqvae/encoder.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/vqvae/encoder.py')
-rw-r--r--text_recognizer/networks/vqvae/encoder.py12
1 files changed, 10 insertions, 2 deletions
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
index d3adac5..ede5c31 100644
--- a/text_recognizer/networks/vqvae/encoder.py
+++ b/text_recognizer/networks/vqvae/encoder.py
@@ -11,7 +11,10 @@ from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer
class _ResidualBlock(nn.Module):
def __init__(
- self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]],
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: Optional[Type[nn.Module]],
) -> None:
super().__init__()
self.block = [
@@ -135,7 +138,12 @@ class Encoder(nn.Module):
)
encoder.append(
- nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,)
+ nn.Conv2d(
+ channels[-1],
+ self.embedding_dim,
+ kernel_size=1,
+ stride=1,
+ )
)
return nn.Sequential(*encoder)