1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
|
"""Residual block."""
from torch import nn
from torch import Tensor
class Residual(nn.Module):
    """Residual block computing ``x + conv1x1(mish(conv3x3(mish(x))))``.

    The branch maps ``in_channels`` to ``out_channels`` with a 3x3 convolution
    (padding 1, so spatial size is preserved) and projects back to
    ``in_channels`` with a 1x1 convolution. The branch output therefore has
    the same shape as its input, and the skip connection is a plain identity.
    Both convolutions are bias-free.

    Args:
        in_channels: Channel count of the input, and therefore of the output.
        out_channels: Channel count of the intermediate feature map inside the
            branch. Despite the name, it is not the block's output width.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.block = nn.Sequential(
            # Must NOT be in-place: this activation receives the block input
            # ``x`` itself. Overwriting it would replace the identity term in
            # ``forward`` with mish(x), mutate the caller's tensor, and break
            # autograd for leaf inputs that require grad.
            nn.Mish(inplace=False),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            # In-place is safe here: it only overwrites the conv's fresh output.
            nn.Mish(inplace=True),
            nn.Conv2d(out_channels, in_channels, kernel_size=1, bias=False),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x + self.block(x)`` without modifying ``x``.

        ``x`` is whatever ``nn.Conv2d`` accepts with ``in_channels`` channels,
        i.e. ``(N, in_channels, H, W)`` or unbatched ``(in_channels, H, W)``.
        The result has the same shape as ``x``.
        """
        return x + self.block(x)
|