author     Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-09-30 23:05:27 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-09-30 23:05:27 +0200
commit     f65b8a48763a6163083b84ddc7a65d33c091adf7 (patch)
tree       dec06cf24cf7f654103d1b290275efa65ce820ca /text_recognizer/networks/vqvae/encoder.py
parent     e115c251a2a14508cb2f234d31bc3a6eb5cc2392 (diff)
Add num res blocks as a variable
Diffstat (limited to 'text_recognizer/networks/vqvae/encoder.py')
-rw-r--r--    text_recognizer/networks/vqvae/encoder.py    27
1 file changed, 18 insertions(+), 9 deletions(-)
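For orientation, here is a minimal usage sketch of the encoder after this change. The hyperparameter values are illustrative assumptions rather than the project's configured settings; it also assumes the constructor arguments elided from the hunks below (in_channels, hidden_dim, channels_multipliers) match the attribute names used in _build_compression_block, and it calls the compression block directly since the forward method is outside this diff.

import torch

from text_recognizer.networks.vqvae.encoder import Encoder

# Illustrative values only; the real settings live in the training configs.
encoder = Encoder(
    in_channels=1,
    hidden_dim=32,
    channels_multipliers=(1, 2, 4),
    dropout_rate=0.1,
    activation="mish",
    use_norm=True,
    num_residuals=2,       # new in this commit: depth of the residual stack
    residual_channels=32,  # new in this commit: width of the residual blocks
)

x = torch.randn(1, 1, 64, 64)
z = encoder.encoder(x)  # three stride-2 convs: 64 -> 32 -> 16 -> 8 spatially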
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
index b8179f0..4761486 100644
--- a/text_recognizer/networks/vqvae/encoder.py
+++ b/text_recognizer/networks/vqvae/encoder.py
@@ -5,6 +5,7 @@ from torch import nn
 from torch import Tensor
 from text_recognizer.networks.util import activation_function
+from text_recognizer.networks.vqvae.norm import Normalize
 from text_recognizer.networks.vqvae.residual import Residual
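The Normalize layer imported above lives in text_recognizer/networks/vqvae/norm.py and is not part of this diff. As a rough sketch of the kind of wrapper this usually is in VQ-VAE encoders, and purely as an assumption about its shape rather than the project's actual implementation, it could look like:

from torch import Tensor, nn


class Normalize(nn.Module):
    """Hypothetical sketch of a channel-normalization wrapper."""

    def __init__(self, num_channels: int, num_groups: int = 32) -> None:
        super().__init__()
        # GroupNorm over the channel dimension is a common choice here;
        # the real module may use a different norm or other defaults.
        self.norm = nn.GroupNorm(
            num_groups=num_groups, num_channels=num_channels, eps=1e-6, affine=True
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.norm(x)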
@@ -19,6 +20,8 @@ class Encoder(nn.Module):
         dropout_rate: float,
         activation: str = "mish",
         use_norm: bool = False,
+        num_residuals: int = 4,
+        residual_channels: int = 32,
     ) -> None:
         super().__init__()
         self.in_channels = in_channels
@@ -27,10 +30,16 @@
         self.activation = activation
         self.dropout_rate = dropout_rate
         self.use_norm = use_norm
+        self.num_residuals = num_residuals
+        self.residual_channels = residual_channels
         self.encoder = self._build_compression_block()
     def _build_compression_block(self) -> nn.Sequential:
         """Builds encoder network."""
+        num_blocks = len(self.channels_multipliers)
+        channels_multipliers = (1,) + self.channels_multipliers
+        activation_fn = activation_function(self.activation)
+
         encoder = [
             nn.Conv2d(
                 in_channels=self.in_channels,
@@ -41,14 +50,15 @@
             ),
         ]
-        num_blocks = len(self.channels_multipliers)
-        channels_multipliers = (1,) + self.channels_multipliers
-        activation_fn = activation_function(self.activation)
-
         for i in range(num_blocks):
             in_channels = self.hidden_dim * channels_multipliers[i]
             out_channels = self.hidden_dim * channels_multipliers[i + 1]
+            if self.use_norm:
+                encoder += [
+                    Normalize(num_channels=in_channels,),
+                ]
             encoder += [
+                activation_fn,
                 nn.Conv2d(
                     in_channels=in_channels,
                     out_channels=out_channels,
@@ -56,16 +66,15 @@
                     stride=2,
                     padding=1,
                 ),
-                activation_fn,
             ]
-        for _ in range(4):
+        for _ in range(self.num_residuals):
             encoder += [
                 Residual(
-                    in_channels=self.hidden_dim * self.channels_multipliers[-1],
-                    out_channels=self.hidden_dim * self.channels_multipliers[-1],
-                    dropout_rate=self.dropout_rate,
+                    in_channels=out_channels,
+                    residual_channels=self.residual_channels,
                     use_norm=self.use_norm,
+                    activation=self.activation,
                 )
             ]
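Pieced together from the hunks above, the post-patch _build_compression_block reads roughly as follows. The arguments of the first convolution and the final return statement fall outside the diff context, so the lines marked as assumptions are plausible fill-ins rather than the file's actual contents.

    def _build_compression_block(self) -> nn.Sequential:
        """Builds encoder network."""
        num_blocks = len(self.channels_multipliers)
        channels_multipliers = (1,) + self.channels_multipliers
        activation_fn = activation_function(self.activation)

        encoder = [
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.hidden_dim,  # assumption: elided from the diff context
                kernel_size=3,                 # assumption
                stride=1,                      # assumption
                padding=1,                     # assumption
            ),
        ]

        for i in range(num_blocks):
            in_channels = self.hidden_dim * channels_multipliers[i]
            out_channels = self.hidden_dim * channels_multipliers[i + 1]
            if self.use_norm:
                encoder += [
                    Normalize(num_channels=in_channels),
                ]
            encoder += [
                activation_fn,
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=4,  # assumption: only stride=2 and padding=1 appear in the context
                    stride=2,
                    padding=1,
                ),
            ]

        for _ in range(self.num_residuals):
            encoder += [
                Residual(
                    in_channels=out_channels,
                    residual_channels=self.residual_channels,
                    use_norm=self.use_norm,
                    activation=self.activation,
                )
            ]

        return nn.Sequential(*encoder)  # assumption: the return lies outside the hunks shown

With use_norm enabled, each downsampling step now runs Normalize, then the activation, then the stride-2 convolution (previously the activation came after the convolution), and the depth of the bottleneck residual stack follows num_residuals instead of being fixed at four.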