From 3ab82ad36bce6fa698a13a029a0694b75a5947b7 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Fri, 6 Aug 2021 02:42:45 +0200
Subject: Fix VQVAE into en/decoder, bug in wandb artifact code uploading

---
 text_recognizer/networks/vqvae/residual.py | 53 +++++++++++++++++++++++++-----
 1 file changed, 45 insertions(+), 8 deletions(-)

(limited to 'text_recognizer/networks/vqvae/residual.py')

diff --git a/text_recognizer/networks/vqvae/residual.py b/text_recognizer/networks/vqvae/residual.py
index 98109b8..4ed3781 100644
--- a/text_recognizer/networks/vqvae/residual.py
+++ b/text_recognizer/networks/vqvae/residual.py
@@ -1,18 +1,55 @@
 """Residual block."""
+import attr
 from torch import nn
 from torch import Tensor
 
+from text_recognizer.networks.vqvae.norm import Normalize
+
 
+@attr.s(eq=False)
 class Residual(nn.Module):
-    def __init__(self, in_channels: int, out_channels: int,) -> None:
+    in_channels: int = attr.ib()
+    out_channels: int = attr.ib()
+    dropout_rate: float = attr.ib(default=0.0)
+    use_norm: bool = attr.ib(default=False)
+
+    def __attrs_post_init__(self) -> None:
+        """Post init configuration."""
         super().__init__()
-        self.block = nn.Sequential(
-            nn.Mish(inplace=True),
-            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
-            nn.Mish(inplace=True),
-            nn.Conv2d(out_channels, in_channels, kernel_size=1, bias=False),
-        )
+        self.block = self._build_res_block()
+        if self.in_channels != self.out_channels:
+            self.conv_shortcut = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=3, stride=1, padding=1)
+        else:
+            self.conv_shortcut = None
+
+    def _build_res_block(self) -> nn.Sequential:
+        """Build residual block."""
+        block = []
+        if self.use_norm:
+            block.append(Normalize(num_channels=self.in_channels))
+        block += [
+            nn.Mish(),
+            nn.Conv2d(
+                self.in_channels,
+                self.out_channels,
+                kernel_size=3,
+                padding=1,
+                bias=False,
+            ),
+        ]
+        if self.dropout_rate:
+            block += [nn.Dropout(p=self.dropout_rate)]
+
+        if self.use_norm:
+            block.append(Normalize(num_channels=self.out_channels))
+
+        block += [
+            nn.Mish(),
+            nn.Conv2d(self.out_channels, self.out_channels, kernel_size=1, bias=False),
+        ]
+        return nn.Sequential(*block)
 
     def forward(self, x: Tensor) -> Tensor:
         """Apply the residual forward pass."""
-        return x + self.block(x)
+        residual = self.conv_shortcut(x) if self.conv_shortcut is not None else x
+        return residual + self.block(x)
--
cgit v1.2.3-70-g09d2
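
For context, a minimal usage sketch of the refactored Residual module (not part of the patch; the import path and field names follow the diff above, while the channel counts and input shape are illustrative):

    import torch
    from text_recognizer.networks.vqvae.residual import Residual

    # Same in/out channels: conv_shortcut stays None, so the plain identity
    # shortcut is added to the block output.
    block = Residual(in_channels=64, out_channels=64, dropout_rate=0.1, use_norm=True)
    out = block(torch.randn(2, 64, 32, 32))  # -> shape (2, 64, 32, 32)

    # Different in/out channels: a 3x3 conv shortcut projects the input to
    # out_channels so the residual addition is shape-compatible.
    proj = Residual(in_channels=64, out_channels=128)
    out = proj(torch.randn(2, 64, 32, 32))  # -> shape (2, 128, 32, 32)

Because attr.s(eq=False) generates the constructor, the fields are passed as keyword arguments, and __attrs_post_init__ calls super().__init__() before registering the submodules it builds.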