summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/encoders/efficientnet/mbconv.py
blob: 7bfd9ba2d278e0a91d69183a07d04e3abe174459 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""Mobile inverted residual block."""
from typing import Optional, Sequence, Union, Tuple

import attr
import torch
from torch import nn, Tensor
import torch.nn.functional as F

from text_recognizer.networks.encoders.efficientnet.utils import stochastic_depth


def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]:
    """Converts int to tuple."""
    return (stride,) * 2 if isinstance(stride, int) else stride


@attr.s(eq=False)
class MBConvBlock(nn.Module):
    """Mobile Inverted Residual Bottleneck (MBConv) block.

    The forward pass applies, in order:

    1. an optional 1x1 expansion conv (only when ``expand_ratio != 1``),
    2. explicit ``F.pad`` followed by a depthwise conv,
    3. an optional squeeze-and-excitation gate (only when ``0 < se_ratio < 1``),
    4. a 1x1 projection to ``out_channels``.

    When the block keeps the stride at ``(1, 1)`` and the channel count
    unchanged, the input is added back as a residual, optionally after
    stochastic depth.

    Attributes:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Depthwise kernel size, an ``int`` or an ``(h, w)`` pair.
        stride: Depthwise stride; an ``int`` is converted to ``(s, s)``.
        bn_momentum: ``momentum`` for every ``nn.BatchNorm2d`` in the block.
        bn_eps: ``eps`` for every ``nn.BatchNorm2d`` in the block.
        se_ratio: Squeeze ratio of the SE bottleneck. ``None`` or a value
            outside the open interval ``(0, 1)`` disables the SE gate.
        expand_ratio: Channel multiplier of the expansion phase.
    """

    def __attrs_pre_init__(self) -> None:
        # nn.Module must be initialised before attrs assigns any attribute,
        # otherwise nn.Module.__setattr__ has no registries to write into.
        super().__init__()

    in_channels: int = attr.ib()
    out_channels: int = attr.ib()
    kernel_size: Union[int, Tuple[int, int]] = attr.ib()
    stride: Tuple[int, int] = attr.ib(converter=_convert_stride)
    bn_momentum: float = attr.ib()
    bn_eps: float = attr.ib()
    se_ratio: float = attr.ib()
    expand_ratio: int = attr.ib()
    # (left, right, top, bottom) amounts passed to F.pad in forward().
    pad: Tuple[int, int, int, int] = attr.ib(init=False)
    # Sub-modules created in _build(); the optional ones are None when disabled.
    _inverted_bottleneck: Optional[nn.Sequential] = attr.ib(init=False)
    _depthwise: nn.Sequential = attr.ib(init=False)
    _squeeze_excite: Optional[nn.Sequential] = attr.ib(init=False)
    _pointwise: nn.Sequential = attr.ib(init=False)

    @pad.default
    def _configure_padding(self) -> Tuple[int, int, int, int]:
        """Compute the padding applied before the depthwise convolution.

        Returns:
            ``(left, right, top, bottom)`` in the order ``F.pad`` expects for
            the last two dimensions. Width padding comes from the kernel
            width, height padding from the kernel height.
        """
        if isinstance(self.kernel_size, int):
            kernel_h = kernel_w = self.kernel_size
        else:
            kernel_h, kernel_w = self.kernel_size

        half_h = (kernel_h - 1) // 2
        half_w = (kernel_w - 1) // 2
        if self.stride == (2, 2):
            # Asymmetric padding: one less on the left/top side. Clamp at 0,
            # because a negative amount makes F.pad crop the input instead of
            # padding it (this happens for kernel sizes 1 and 2).
            return (max(0, half_w - 1), half_w, max(0, half_h - 1), half_h)
        return (half_w, half_w, half_h, half_h)

    def __attrs_post_init__(self) -> None:
        """Build the sub-modules once all attrs fields have been set."""
        self._build()

    def _build(self) -> None:
        """Create and register the expansion, depthwise, SE and pointwise layers."""
        has_se = self.se_ratio is not None and 0.0 < self.se_ratio < 1.0
        inner_channels = self.in_channels * self.expand_ratio

        # With expand_ratio == 1 there is nothing to expand, so skip the stage.
        self._inverted_bottleneck = (
            self._configure_inverted_bottleneck(out_channels=inner_channels)
            if self.expand_ratio != 1
            else None
        )

        # groups == channels makes this a depthwise convolution.
        self._depthwise = self._configure_depthwise(
            in_channels=inner_channels,
            out_channels=inner_channels,
            groups=inner_channels,
        )

        self._squeeze_excite = (
            self._configure_squeeze_excite(
                in_channels=inner_channels, out_channels=inner_channels,
            )
            if has_se
            else None
        )

        self._pointwise = self._configure_pointwise(in_channels=inner_channels)

    def _configure_inverted_bottleneck(self, out_channels: int) -> nn.Sequential:
        """Expansion phase: 1x1 conv (no bias) -> BatchNorm -> Mish."""
        return nn.Sequential(
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=out_channels,
                kernel_size=1,
                bias=False,
            ),
            nn.BatchNorm2d(
                num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
            ),
            nn.Mish(inplace=True),
        )

    def _configure_depthwise(
        self, in_channels: int, out_channels: int, groups: int,
    ) -> nn.Sequential:
        """Grouped conv (no bias, no built-in padding) -> BatchNorm -> Mish.

        Padding is applied explicitly with ``F.pad`` in ``forward`` so that
        the stride-2 case can be padded asymmetrically.
        """
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=self.kernel_size,
                stride=self.stride,
                groups=groups,
                bias=False,
            ),
            nn.BatchNorm2d(
                num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
            ),
            nn.Mish(inplace=True),
        )

    def _configure_squeeze_excite(
        self, in_channels: int, out_channels: int
    ) -> nn.Sequential:
        """Excitation MLP as two 1x1 convs with a Mish in between.

        The bottleneck width is ``max(1, int(in_channels * se_ratio))``.
        """
        num_squeezed_channels = max(1, int(in_channels * self.se_ratio))
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=num_squeezed_channels,
                kernel_size=1,
            ),
            nn.Mish(inplace=True),
            nn.Conv2d(
                in_channels=num_squeezed_channels,
                out_channels=out_channels,
                kernel_size=1,
            ),
        )

    def _configure_pointwise(self, in_channels: int) -> nn.Sequential:
        """Projection phase: 1x1 conv (no bias) -> BatchNorm, no activation."""
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=self.out_channels,
                kernel_size=1,
                bias=False,
            ),
            nn.BatchNorm2d(
                num_features=self.out_channels,
                momentum=self.bn_momentum,
                eps=self.bn_eps,
            ),
        )

    def _stochastic_depth(
        self, x: Tensor, residual: Tensor, stochastic_dropout_rate: Optional[float]
    ) -> Tensor:
        """Add the skip connection when stride and channel count are preserved.

        If ``stochastic_dropout_rate`` is truthy, ``x`` is first passed through
        ``stochastic_depth`` (active according to ``self.training``). Otherwise
        ``x`` is returned without a residual.
        """
        if self.stride == (1, 1) and self.in_channels == self.out_channels:
            if stochastic_dropout_rate:
                x = stochastic_depth(
                    x, p=stochastic_dropout_rate, training=self.training
                )
            x += residual
        return x

    def forward(
        self, x: Tensor, stochastic_dropout_rate: Optional[float] = None
    ) -> Tensor:
        """Run the MBConv block.

        Args:
            x: Input with ``in_channels`` channels, in the layout accepted by
                ``nn.Conv2d`` (presumably ``(N, C, H, W)``; confirm with callers).
            stochastic_dropout_rate: Drop probability for stochastic depth on
                the residual path; ``None``/``0`` disables it.

        Returns:
            Tensor with ``out_channels`` channels.
        """
        residual = x
        if self._inverted_bottleneck is not None:
            x = self._inverted_bottleneck(x)

        x = F.pad(x, self.pad)
        x = self._depthwise(x)

        if self._squeeze_excite is not None:
            # Squeeze: global average pool to one value per channel. Excite:
            # run the pooled descriptor (not the full map) through the 1x1 convs.
            x_squeezed = F.adaptive_avg_pool2d(x, 1)
            x_squeezed = self._squeeze_excite(x_squeezed)
            # tanh(softplus(.)) lies in (0, 1) and acts as a per-channel gate
            # that broadcasts over the spatial dimensions.
            x = torch.tanh(F.softplus(x_squeezed)) * x

        x = self._pointwise(x)

        # Residual connection, with optional stochastic depth.
        return self._stochastic_depth(x, residual, stochastic_dropout_rate)