path: root/text_recognizer/networks/vqvae
author    Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-09-30 23:05:47 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-09-30 23:05:47 +0200
commit    e653bb1ebc6b7fb2e605eff8c64d6402edd38473 (patch)
tree      9f210c7a130a5da9871786f52f8ce969d454d7f5 /text_recognizer/networks/vqvae
parent    f65b8a48763a6163083b84ddc7a65d33c091adf7 (diff)
Refactor residual block
Diffstat (limited to 'text_recognizer/networks/vqvae')
-rw-r--r--    text_recognizer/networks/vqvae/residual.py    35
1 file changed, 14 insertions(+), 21 deletions(-)
diff --git a/text_recognizer/networks/vqvae/residual.py b/text_recognizer/networks/vqvae/residual.py
index 46b091d..bdff9eb 100644
--- a/text_recognizer/networks/vqvae/residual.py
+++ b/text_recognizer/networks/vqvae/residual.py
@@ -3,59 +3,52 @@ import attr
 from torch import nn
 from torch import Tensor
 
+from text_recognizer.networks.util import activation_function
 from text_recognizer.networks.vqvae.norm import Normalize
 
 
 @attr.s(eq=False)
 class Residual(nn.Module):
     in_channels: int = attr.ib()
-    out_channels: int = attr.ib()
-    dropout_rate: float = attr.ib(default=0.0)
+    residual_channels: int = attr.ib()
     use_norm: bool = attr.ib(default=False)
+    activation: str = attr.ib(default="relu")
 
     def __attrs_post_init__(self) -> None:
         """Post init configuration."""
         super().__init__()
         self.block = self._build_res_block()
-        if self.in_channels != self.out_channels:
-            self.conv_shortcut = nn.Conv2d(
-                in_channels=self.in_channels,
-                out_channels=self.out_channels,
-                kernel_size=3,
-                stride=1,
-                padding=1,
-            )
-        else:
-            self.conv_shortcut = None
 
     def _build_res_block(self) -> nn.Sequential:
         """Build residual block."""
         block = []
+        activation_fn = activation_function(activation=self.activation)
+
         if self.use_norm:
             block.append(Normalize(num_channels=self.in_channels))
+
         block += [
-            nn.Mish(),
+            activation_fn,
             nn.Conv2d(
                 self.in_channels,
-                self.out_channels,
+                self.residual_channels,
                 kernel_size=3,
                 padding=1,
                 bias=False,
             ),
         ]
-        if self.dropout_rate:
-            block += [nn.Dropout(p=self.dropout_rate)]
         if self.use_norm:
-            block.append(Normalize(num_channels=self.out_channels))
+            block.append(Normalize(num_channels=self.residual_channels))
         block += [
-            nn.Mish(),
-            nn.Conv2d(self.out_channels, self.out_channels, kernel_size=1, bias=False),
+            activation_fn,
+            nn.Conv2d(
+                self.residual_channels, self.in_channels, kernel_size=1, bias=False
+            ),
         ]
         return nn.Sequential(*block)
 
     def forward(self, x: Tensor) -> Tensor:
         """Apply the residual forward pass."""
-        residual = self.conv_shortcut(x) if self.conv_shortcut is not None else x
-        return residual + self.block(x)
+        return x + self.block(x)
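
For context, a minimal usage sketch of the refactored block (not part of this commit). It assumes the package is importable as-is and that the project's `activation_function` helper accepts the default "relu" name; the channel widths, input shape, and `use_norm` flag below are illustrative only:

    # Hypothetical usage of the refactored Residual block; values are illustrative.
    import torch

    from text_recognizer.networks.vqvae.residual import Residual

    # The block now bottlenecks in_channels -> residual_channels with a 3x3 conv
    # and projects back to in_channels with a 1x1 conv, so the identity skip
    # connection needs no conv shortcut (the old conv_shortcut is removed).
    block = Residual(in_channels=64, residual_channels=32, use_norm=True)
    x = torch.randn(1, 64, 32, 32)
    out = block(x)
    assert out.shape == x.shape  # channels preserved by construction

Because the second convolution always maps back to `in_channels`, the output shape matches the input for any `residual_channels`, which is what lets the shortcut branch reduce to a plain identity in `forward`.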