diff options
Diffstat (limited to 'text_recognizer/networks/vqvae/residual.py')
-rw-r--r-- | text_recognizer/networks/vqvae/residual.py | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/text_recognizer/networks/vqvae/residual.py b/text_recognizer/networks/vqvae/residual.py new file mode 100644 index 0000000..98109b8 --- /dev/null +++ b/text_recognizer/networks/vqvae/residual.py @@ -0,0 +1,18 @@ +"""Residual block.""" +from torch import nn +from torch import Tensor + + +class Residual(nn.Module): + def __init__(self, in_channels: int, out_channels: int,) -> None: + super().__init__() + self.block = nn.Sequential( + nn.Mish(inplace=True), + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.Mish(inplace=True), + nn.Conv2d(out_channels, in_channels, kernel_size=1, bias=False), + ) + + def forward(self, x: Tensor) -> Tensor: + """Apply the residual forward pass.""" + return x + self.block(x) |