diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-07 23:38:48 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-07 23:38:48 +0100 |
commit | 5b7d6eec4160f7b3ea0862915b9428b37a3d49fa (patch) | |
tree | bbcaf670def50f81bb2729b543d9a098cdce1f09 /text_recognizer/networks/encoders/efficientnet | |
parent | f8db5953a63c2467a76cc0610eb2ad0c96b69c70 (diff) |
Fix mbconv sub modules
Diffstat (limited to 'text_recognizer/networks/encoders/efficientnet')
-rw-r--r-- | text_recognizer/networks/encoders/efficientnet/mbconv.py | 207 |
1 files changed, 139 insertions, 68 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/encoders/efficientnet/mbconv.py index f01c369..96626fc 100644 --- a/text_recognizer/networks/encoders/efficientnet/mbconv.py +++ b/text_recognizer/networks/encoders/efficientnet/mbconv.py @@ -1,5 +1,5 @@ """Mobile inverted residual block.""" -from typing import Optional, Sequence, Union, Tuple +from typing import Optional, Tuple, Union import attr import torch @@ -15,119 +15,112 @@ def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]: @attr.s(eq=False) -class MBConvBlock(nn.Module): - """Mobile Inverted Residual Bottleneck block.""" - - def __attrs_pre_init__(self) -> None: - super().__init__() +class BaseModule(nn.Module): + """Base sub module class.""" - in_channels: int = attr.ib() - out_channels: int = attr.ib() - kernel_size: Tuple[int, int] = attr.ib() - stride: Tuple[int, int] = attr.ib(converter=_convert_stride) bn_momentum: float = attr.ib() bn_eps: float = attr.ib() - se_ratio: float = attr.ib() - expand_ratio: int = attr.ib() - pad: Tuple[int, int, int, int] = attr.ib(init=False) - _inverted_bottleneck: nn.Sequential = attr.ib(init=False) - _depthwise: nn.Sequential = attr.ib(init=False) - _squeeze_excite: nn.Sequential = attr.ib(init=False) - _pointwise: nn.Sequential = attr.ib(init=False) + block: nn.Sequential = attr.ib(init=False) - @pad.default - def _configure_padding(self) -> Tuple[int, int, int, int]: - """Set padding for convolutional layers.""" - if self.stride == (2, 2): - return ( - (self.kernel_size - 1) // 2 - 1, - (self.kernel_size - 1) // 2, - ) * 2 - return ((self.kernel_size - 1) // 2,) * 4 + def __attrs_pre_init__(self) -> None: + super().__init__() def __attrs_post_init__(self) -> None: - """Post init configuration.""" self._build() def _build(self) -> None: - has_se = self.se_ratio is not None and 0.0 < self.se_ratio < 1.0 - inner_channels = self.in_channels * self.expand_ratio - self._inverted_bottleneck = ( - self._configure_inverted_bottleneck(out_channels=inner_channels) - if self.expand_ratio != 1 - else None - ) + pass - self._depthwise = self._configure_depthwise( - channels=inner_channels, - groups=inner_channels, - ) + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + return self.block(x) - self._squeeze_excite = ( - self._configure_squeeze_excite( - channels=inner_channels, - ) - if has_se - else None - ) - self._pointwise = self._configure_pointwise(in_channels=inner_channels) +@attr.s(auto_attribs=True, eq=False) +class InvertedBottleneck(BaseModule): + """Inverted bottleneck module.""" + + in_channels: int = attr.ib() + out_channels: int = attr.ib() - def _configure_inverted_bottleneck(self, out_channels: int) -> nn.Sequential: - """Expansion phase.""" - return nn.Sequential( + def _build(self) -> None: + self.block = nn.Sequential( nn.Conv2d( in_channels=self.in_channels, - out_channels=out_channels, + out_channels=self.out_channels, kernel_size=1, bias=False, ), nn.BatchNorm2d( - num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps + num_features=self.out_channels, + momentum=self.bn_momentum, + eps=self.bn_eps, ), nn.Mish(inplace=True), ) - def _configure_depthwise( - self, - channels: int, - groups: int, - ) -> nn.Sequential: - return nn.Sequential( + +@attr.s(auto_attribs=True, eq=False) +class Depthwise(BaseModule): + """Depthwise convolution module.""" + + channels: int = attr.ib() + kernel_size: int = attr.ib() + stride: int = attr.ib() + + def _build(self) -> None: + self.block = nn.Sequential( nn.Conv2d( - in_channels=channels, - out_channels=channels, + in_channels=self.channels, + out_channels=self.channels, kernel_size=self.kernel_size, stride=self.stride, - groups=groups, + groups=self.channels, bias=False, ), nn.BatchNorm2d( - num_features=channels, momentum=self.bn_momentum, eps=self.bn_eps + num_features=self.channels, momentum=self.bn_momentum, eps=self.bn_eps ), nn.Mish(inplace=True), ) - def _configure_squeeze_excite(self, channels: int) -> nn.Sequential: + +@attr.s(auto_attribs=True, eq=False) +class SqueezeAndExcite(BaseModule): + """Sequeeze and excite module.""" + + in_channels: int = attr.ib() + channels: int = attr.ib() + se_ratio: float = attr.ib() + + def _build(self) -> None: num_squeezed_channels = max(1, int(self.in_channels * self.se_ratio)) - return nn.Sequential( + self.block = nn.Sequential( nn.Conv2d( - in_channels=channels, + in_channels=self.channels, out_channels=num_squeezed_channels, kernel_size=1, ), nn.Mish(inplace=True), nn.Conv2d( in_channels=num_squeezed_channels, - out_channels=channels, + out_channels=self.channels, kernel_size=1, ), ) - def _configure_pointwise(self, in_channels: int) -> nn.Sequential: - return nn.Sequential( + +@attr.s(auto_attribs=True, eq=False) +class Pointwise(BaseModule): + """Pointwise module.""" + + in_channels: int = attr.ib() + out_channels: int = attr.ib() + + def _build(self) -> None: + self.block = nn.Sequential( nn.Conv2d( - in_channels=in_channels, + in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=1, bias=False, @@ -139,6 +132,83 @@ class MBConvBlock(nn.Module): ), ) + +@attr.s(eq=False) +class MBConvBlock(nn.Module): + """Mobile Inverted Residual Bottleneck block.""" + + def __attrs_pre_init__(self) -> None: + super().__init__() + + in_channels: int = attr.ib() + out_channels: int = attr.ib() + kernel_size: Tuple[int, int] = attr.ib() + stride: Tuple[int, int] = attr.ib(converter=_convert_stride) + bn_momentum: float = attr.ib() + bn_eps: float = attr.ib() + se_ratio: float = attr.ib() + expand_ratio: int = attr.ib() + pad: Tuple[int, int, int, int] = attr.ib(init=False) + _inverted_bottleneck: Optional[InvertedBottleneck] = attr.ib(init=False) + _depthwise: nn.Sequential = attr.ib(init=False) + _squeeze_excite: nn.Sequential = attr.ib(init=False) + _pointwise: nn.Sequential = attr.ib(init=False) + + @pad.default + def _configure_padding(self) -> Tuple[int, int, int, int]: + """Set padding for convolutional layers.""" + if self.stride == (2, 2): + return ( + (self.kernel_size - 1) // 2 - 1, + (self.kernel_size - 1) // 2, + ) * 2 + return ((self.kernel_size - 1) // 2,) * 4 + + def __attrs_post_init__(self) -> None: + """Post init configuration.""" + self._build() + + def _build(self) -> None: + has_se = self.se_ratio is not None and 0.0 < self.se_ratio < 1.0 + inner_channels = self.in_channels * self.expand_ratio + self._inverted_bottleneck = ( + InvertedBottleneck( + in_channels=self.in_channels, + out_channels=inner_channels, + bn_momentum=self.bn_momentum, + bn_eps=self.bn_eps, + ) + if self.expand_ratio != 1 + else None + ) + + self._depthwise = Depthwise( + channels=inner_channels, + kernel_size=self.kernel_size, + stride=self.stride, + bn_momentum=self.bn_momentum, + bn_eps=self.bn_eps, + ) + + self._squeeze_excite = ( + SqueezeAndExcite( + in_channels=self.in_channels, + channels=inner_channels, + se_ratio=self.se_ratio, + bn_momentum=self.bn_momentum, + bn_eps=self.bn_eps, + ) + if has_se + else None + ) + + self._pointwise = Pointwise( + in_channels=inner_channels, + out_channels=self.out_channels, + bn_momentum=self.bn_momentum, + bn_eps=self.bn_eps, + ) + def _stochastic_depth( self, x: Tensor, residual: Tensor, stochastic_dropout_rate: Optional[float] ) -> Tensor: @@ -153,6 +223,7 @@ class MBConvBlock(nn.Module): def forward( self, x: Tensor, stochastic_dropout_rate: Optional[float] = None ) -> Tensor: + """Forward pass.""" residual = x if self._inverted_bottleneck is not None: x = self._inverted_bottleneck(x) |