"""Residual block."""
import attr
from torch import nn
from torch import Tensor
from text_recognizer.networks.vqvae.norm import Normalize

@attr.s(eq=False)
class Residual(nn.Module):
    """Residual block with an optional projection shortcut."""

    in_channels: int = attr.ib()
    out_channels: int = attr.ib()
    dropout_rate: float = attr.ib(default=0.0)
    use_norm: bool = attr.ib(default=False)

    def __attrs_post_init__(self) -> None:
        """Post init configuration."""
        # attrs calls this after its generated __init__; nn.Module must be
        # initialized before any submodules are assigned.
        super().__init__()
        self.block = self._build_res_block()
        if self.in_channels != self.out_channels:
            # Project the input so it can be added to the block output.
            self.conv_shortcut = nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
            )
        else:
            self.conv_shortcut = None

    def _build_res_block(self) -> nn.Sequential:
        """Build residual block."""
        block = []
        if self.use_norm:
            block.append(Normalize(num_channels=self.in_channels))
        block += [
            nn.Mish(),
            nn.Conv2d(
                self.in_channels,
                self.out_channels,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
        ]
        if self.dropout_rate:
            block.append(nn.Dropout(p=self.dropout_rate))
        if self.use_norm:
            block.append(Normalize(num_channels=self.out_channels))
        block += [
            nn.Mish(),
            nn.Conv2d(self.out_channels, self.out_channels, kernel_size=1, bias=False),
        ]
        return nn.Sequential(*block)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the residual forward pass."""
        residual = self.conv_shortcut(x) if self.conv_shortcut is not None else x
        return residual + self.block(x)
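

# A minimal usage sketch (illustrative only; the shapes below are assumptions,
# not part of this module). The block maps (B, in_channels, H, W) to
# (B, out_channels, H, W), since every conv uses stride 1 with matching padding:
#
#     import torch
#
#     residual = Residual(in_channels=32, out_channels=64, dropout_rate=0.1, use_norm=True)
#     out = residual(torch.randn(1, 32, 28, 28))  # -> torch.Size([1, 64, 28, 28])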