diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-04 05:03:51 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-04 05:03:51 +0200 |
commit | d3afa310f77f47553586eeee58e3d3345a754e2c (patch) | |
tree | 08b7de1daf2550852d0a1e4d4d75202f14bb03d4 /text_recognizer/networks/vqvae/residual.py | |
parent | 65d5f6c694e73792e40ed693a1381a792da8d277 (diff) |
New VQVAE
Diffstat (limited to 'text_recognizer/networks/vqvae/residual.py')
-rw-r--r-- | text_recognizer/networks/vqvae/residual.py | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/text_recognizer/networks/vqvae/residual.py b/text_recognizer/networks/vqvae/residual.py new file mode 100644 index 0000000..98109b8 --- /dev/null +++ b/text_recognizer/networks/vqvae/residual.py @@ -0,0 +1,18 @@ +"""Residual block.""" +from torch import nn +from torch import Tensor + + +class Residual(nn.Module): + def __init__(self, in_channels: int, out_channels: int,) -> None: + super().__init__() + self.block = nn.Sequential( + nn.Mish(inplace=True), + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.Mish(inplace=True), + nn.Conv2d(out_channels, in_channels, kernel_size=1, bias=False), + ) + + def forward(self, x: Tensor) -> Tensor: + """Apply the residual forward pass.""" + return x + self.block(x) |