-rw-r--r--  text_recognizer/networks/vqvae/residual.py | 35 ++++++++++++---------------------
1 file changed, 14 insertions(+), 21 deletions(-)
diff --git a/text_recognizer/networks/vqvae/residual.py b/text_recognizer/networks/vqvae/residual.py
index 46b091d..bdff9eb 100644
--- a/text_recognizer/networks/vqvae/residual.py
+++ b/text_recognizer/networks/vqvae/residual.py
@@ -3,59 +3,52 @@ import attr
 from torch import nn
 from torch import Tensor
+from text_recognizer.networks.util import activation_function
 from text_recognizer.networks.vqvae.norm import Normalize
 
 
 @attr.s(eq=False)
 class Residual(nn.Module):
     in_channels: int = attr.ib()
-    out_channels: int = attr.ib()
-    dropout_rate: float = attr.ib(default=0.0)
+    residual_channels: int = attr.ib()
     use_norm: bool = attr.ib(default=False)
+    activation: str = attr.ib(default="relu")
 
     def __attrs_post_init__(self) -> None:
         """Post init configuration."""
         super().__init__()
         self.block = self._build_res_block()
-        if self.in_channels != self.out_channels:
-            self.conv_shortcut = nn.Conv2d(
-                in_channels=self.in_channels,
-                out_channels=self.out_channels,
-                kernel_size=3,
-                stride=1,
-                padding=1,
-            )
-        else:
-            self.conv_shortcut = None
 
     def _build_res_block(self) -> nn.Sequential:
         """Build residual block."""
         block = []
+        activation_fn = activation_function(activation=self.activation)
+
         if self.use_norm:
             block.append(Normalize(num_channels=self.in_channels))
+
         block += [
-            nn.Mish(),
+            activation_fn,
             nn.Conv2d(
                 self.in_channels,
-                self.out_channels,
+                self.residual_channels,
                 kernel_size=3,
                 padding=1,
                 bias=False,
             ),
         ]
-        if self.dropout_rate:
-            block += [nn.Dropout(p=self.dropout_rate)]
         if self.use_norm:
-            block.append(Normalize(num_channels=self.out_channels))
+            block.append(Normalize(num_channels=self.residual_channels))
         block += [
-            nn.Mish(),
-            nn.Conv2d(self.out_channels, self.out_channels, kernel_size=1, bias=False),
+            activation_fn,
+            nn.Conv2d(
+                self.residual_channels, self.in_channels, kernel_size=1, bias=False
+            ),
         ]
         return nn.Sequential(*block)
 
     def forward(self, x: Tensor) -> Tensor:
         """Apply the residual forward pass."""
-        residual = self.conv_shortcut(x) if self.conv_shortcut is not None else x
-        return residual + self.block(x)
+        return x + self.block(x)
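
For reference, a minimal usage sketch of the refactored block (the channel counts, the input shape, and the torch import are illustrative assumptions; only the constructor arguments come from the diff above). Because the block is now a bottleneck that maps in_channels -> residual_channels -> in_channels, the output shape always matches the input, which is why the conv_shortcut projection could be dropped in favor of a plain identity shortcut:

# Hypothetical usage; channel counts and input shape are illustrative.
import torch

from text_recognizer.networks.vqvae.residual import Residual

block = Residual(in_channels=64, residual_channels=32, use_norm=True)

x = torch.randn(1, 64, 32, 32)  # (batch, channels, height, width)
out = block(x)

# 3x3 conv (64 -> 32, padding=1) followed by 1x1 conv (32 -> 64): both the
# spatial and channel dimensions are preserved, so x + block(x) is valid.
assert out.shape == x.shape

The activation defaults to "relu" per the new attr.ib; any other name would have to be one that the imported activation_function helper recognizes.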