summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/vqvae/encoder.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-08 19:59:55 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-08 19:59:55 +0200
commit240f5e9f20032e82515fa66ce784619527d1041e (patch)
treeb002d28bbfc9abe9b6af090f7db60bea0aeed6e8 /text_recognizer/networks/vqvae/encoder.py
parentd12f70402371dda586d457af2a3df7fb5b3130ad (diff)
Add VQGAN and loss function
Diffstat (limited to 'text_recognizer/networks/vqvae/encoder.py')
-rw-r--r--text_recognizer/networks/vqvae/encoder.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
index ad8f950..4a5c976 100644
--- a/text_recognizer/networks/vqvae/encoder.py
+++ b/text_recognizer/networks/vqvae/encoder.py
@@ -11,7 +11,14 @@ from text_recognizer.networks.vqvae.residual import Residual
class Encoder(nn.Module):
"""A CNN encoder network."""
- def __init__(self, in_channels: int, hidden_dim: int, channels_multipliers: List[int], dropout_rate: float, activation: str = "mish") -> None:
+ def __init__(
+ self,
+ in_channels: int,
+ hidden_dim: int,
+ channels_multipliers: List[int],
+ dropout_rate: float,
+ activation: str = "mish",
+ ) -> None:
super().__init__()
self.in_channels = in_channels
self.hidden_dim = hidden_dim
@@ -33,7 +40,7 @@ class Encoder(nn.Module):
]
num_blocks = len(self.channels_multipliers)
- channels_multipliers = (1, ) + self.channels_multipliers
+ channels_multipliers = (1,) + self.channels_multipliers
activation_fn = activation_function(self.activation)
for i in range(num_blocks):