Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--  text_recognizer/networks/encoders/efficientnet/mbconv.py | 207
1 file changed, 139 insertions(+), 68 deletions(-)
diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/encoders/efficientnet/mbconv.py
index f01c369..96626fc 100644
--- a/text_recognizer/networks/encoders/efficientnet/mbconv.py
+++ b/text_recognizer/networks/encoders/efficientnet/mbconv.py
@@ -1,5 +1,5 @@
 """Mobile inverted residual block."""
-from typing import Optional, Sequence, Union, Tuple
+from typing import Optional, Tuple, Union
 
 import attr
 import torch
@@ -15,119 +15,112 @@ def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]:
 
 
 @attr.s(eq=False)
-class MBConvBlock(nn.Module):
-    """Mobile Inverted Residual Bottleneck block."""
-
-    def __attrs_pre_init__(self) -> None:
-        super().__init__()
+class BaseModule(nn.Module):
+    """Base submodule class."""
 
-    in_channels: int = attr.ib()
-    out_channels: int = attr.ib()
-    kernel_size: Tuple[int, int] = attr.ib()
-    stride: Tuple[int, int] = attr.ib(converter=_convert_stride)
     bn_momentum: float = attr.ib()
     bn_eps: float = attr.ib()
-    se_ratio: float = attr.ib()
-    expand_ratio: int = attr.ib()
-    pad: Tuple[int, int, int, int] = attr.ib(init=False)
-    _inverted_bottleneck: nn.Sequential = attr.ib(init=False)
-    _depthwise: nn.Sequential = attr.ib(init=False)
-    _squeeze_excite: nn.Sequential = attr.ib(init=False)
-    _pointwise: nn.Sequential = attr.ib(init=False)
+    block: nn.Sequential = attr.ib(init=False)
 
-    @pad.default
-    def _configure_padding(self) -> Tuple[int, int, int, int]:
-        """Set padding for convolutional layers."""
-        if self.stride == (2, 2):
-            return (
-                (self.kernel_size - 1) // 2 - 1,
-                (self.kernel_size - 1) // 2,
-            ) * 2
-        return ((self.kernel_size - 1) // 2,) * 4
+    def __attrs_pre_init__(self) -> None:
+        super().__init__()
 
     def __attrs_post_init__(self) -> None:
-        """Post init configuration."""
         self._build()
 
     def _build(self) -> None:
-        has_se = self.se_ratio is not None and 0.0 < self.se_ratio < 1.0
-        inner_channels = self.in_channels * self.expand_ratio
-        self._inverted_bottleneck = (
-            self._configure_inverted_bottleneck(out_channels=inner_channels)
-            if self.expand_ratio != 1
-            else None
-        )
+        pass
 
-        self._depthwise = self._configure_depthwise(
-            channels=inner_channels,
-            groups=inner_channels,
-        )
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward pass."""
+        return self.block(x)
 
-        self._squeeze_excite = (
-            self._configure_squeeze_excite(
-                channels=inner_channels,
-            )
-            if has_se
-            else None
-        )
-        self._pointwise = self._configure_pointwise(in_channels=inner_channels)
 
+@attr.s(auto_attribs=True, eq=False)
+class InvertedBottleneck(BaseModule):
+    """Inverted bottleneck module."""
+
+    in_channels: int = attr.ib()
+    out_channels: int = attr.ib()
 
-    def _configure_inverted_bottleneck(self, out_channels: int) -> nn.Sequential:
-        """Expansion phase."""
-        return nn.Sequential(
+    def _build(self) -> None:
+        self.block = nn.Sequential(
             nn.Conv2d(
                 in_channels=self.in_channels,
-                out_channels=out_channels,
+                out_channels=self.out_channels,
                 kernel_size=1,
                 bias=False,
             ),
             nn.BatchNorm2d(
-                num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
+                num_features=self.out_channels,
+                momentum=self.bn_momentum,
+                eps=self.bn_eps,
             ),
             nn.Mish(inplace=True),
         )
 
-    def _configure_depthwise(
-        self,
-        channels: int,
-        groups: int,
-    ) -> nn.Sequential:
-        return nn.Sequential(
+
+@attr.s(auto_attribs=True, eq=False)
+class Depthwise(BaseModule):
+    """Depthwise convolution module."""
+
+    channels: int = attr.ib()
+    kernel_size: int = attr.ib()
+    stride: int = attr.ib()
+
+    def _build(self) -> None:
+        self.block = nn.Sequential(
             nn.Conv2d(
-                in_channels=channels,
-                out_channels=channels,
+                in_channels=self.channels,
+                out_channels=self.channels,
                 kernel_size=self.kernel_size,
                 stride=self.stride,
-                groups=groups,
+                groups=self.channels,
                 bias=False,
             ),
             nn.BatchNorm2d(
-                num_features=channels, momentum=self.bn_momentum, eps=self.bn_eps
+                num_features=self.channels, momentum=self.bn_momentum, eps=self.bn_eps
             ),
             nn.Mish(inplace=True),
         )
 
-    def _configure_squeeze_excite(self, channels: int) -> nn.Sequential:
+
+@attr.s(auto_attribs=True, eq=False)
+class SqueezeAndExcite(BaseModule):
+    """Squeeze and excite module."""
+
+    in_channels: int = attr.ib()
+    channels: int = attr.ib()
+    se_ratio: float = attr.ib()
+
+    def _build(self) -> None:
         num_squeezed_channels = max(1, int(self.in_channels * self.se_ratio))
-        return nn.Sequential(
+        self.block = nn.Sequential(
             nn.Conv2d(
-                in_channels=channels,
+                in_channels=self.channels,
                 out_channels=num_squeezed_channels,
                 kernel_size=1,
             ),
             nn.Mish(inplace=True),
             nn.Conv2d(
                 in_channels=num_squeezed_channels,
-                out_channels=channels,
+                out_channels=self.channels,
                 kernel_size=1,
             ),
         )
 
-    def _configure_pointwise(self, in_channels: int) -> nn.Sequential:
-        return nn.Sequential(
+
+@attr.s(auto_attribs=True, eq=False)
+class Pointwise(BaseModule):
+    """Pointwise module."""
+
+    in_channels: int = attr.ib()
+    out_channels: int = attr.ib()
+
+    def _build(self) -> None:
+        self.block = nn.Sequential(
             nn.Conv2d(
-                in_channels=in_channels,
+                in_channels=self.in_channels,
                 out_channels=self.out_channels,
                 kernel_size=1,
                 bias=False,
@@ -139,6 +132,83 @@ class MBConvBlock(nn.Module):
             ),
         )
 
+
+@attr.s(eq=False)
+class MBConvBlock(nn.Module):
+    """Mobile Inverted Residual Bottleneck block."""
+
+    def __attrs_pre_init__(self) -> None:
+        super().__init__()
+
+    in_channels: int = attr.ib()
+    out_channels: int = attr.ib()
+    kernel_size: Tuple[int, int] = attr.ib()
+    stride: Tuple[int, int] = attr.ib(converter=_convert_stride)
+    bn_momentum: float = attr.ib()
+    bn_eps: float = attr.ib()
+    se_ratio: float = attr.ib()
+    expand_ratio: int = attr.ib()
+    pad: Tuple[int, int, int, int] = attr.ib(init=False)
+    _inverted_bottleneck: Optional[InvertedBottleneck] = attr.ib(init=False)
+    _depthwise: nn.Sequential = attr.ib(init=False)
+    _squeeze_excite: nn.Sequential = attr.ib(init=False)
+    _pointwise: nn.Sequential = attr.ib(init=False)
+
+    @pad.default
+    def _configure_padding(self) -> Tuple[int, int, int, int]:
+        """Set padding for convolutional layers."""
+        if self.stride == (2, 2):
+            return (
+                (self.kernel_size - 1) // 2 - 1,
+                (self.kernel_size - 1) // 2,
+            ) * 2
+        return ((self.kernel_size - 1) // 2,) * 4
+
+    def __attrs_post_init__(self) -> None:
+        """Post init configuration."""
+        self._build()
+
+    def _build(self) -> None:
+        has_se = self.se_ratio is not None and 0.0 < self.se_ratio < 1.0
+        inner_channels = self.in_channels * self.expand_ratio
+        self._inverted_bottleneck = (
+            InvertedBottleneck(
+                in_channels=self.in_channels,
+                out_channels=inner_channels,
+                bn_momentum=self.bn_momentum,
+                bn_eps=self.bn_eps,
+            )
+            if self.expand_ratio != 1
+            else None
+        )
+
+        self._depthwise = Depthwise(
+            channels=inner_channels,
+            kernel_size=self.kernel_size,
+            stride=self.stride,
+            bn_momentum=self.bn_momentum,
+            bn_eps=self.bn_eps,
+        )
+
+        self._squeeze_excite = (
+            SqueezeAndExcite(
+                in_channels=self.in_channels,
+                channels=inner_channels,
+                se_ratio=self.se_ratio,
+                bn_momentum=self.bn_momentum,
+                bn_eps=self.bn_eps,
+            )
+            if has_se
+            else None
+        )
+
+        self._pointwise = Pointwise(
+            in_channels=inner_channels,
+            out_channels=self.out_channels,
+            bn_momentum=self.bn_momentum,
+            bn_eps=self.bn_eps,
+        )
+
     def _stochastic_depth(
         self, x: Tensor, residual: Tensor, stochastic_dropout_rate: Optional[float]
     ) -> Tensor:
@@ -153,6 +223,7 @@ class MBConvBlock(nn.Module):
     def forward(
         self, x: Tensor, stochastic_dropout_rate: Optional[float] = None
    ) -> Tensor:
+        """Forward pass."""
         residual = x
         if self._inverted_bottleneck is not None:
             x = self._inverted_bottleneck(x)
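A quick usage sketch of the refactored block follows. It assumes only the constructor signature visible in the diff above; the hyperparameter values and tensor sizes are illustrative, not taken from this repository's configs.

import torch

from text_recognizer.networks.encoders.efficientnet.mbconv import MBConvBlock

# Illustrative hyperparameters (assumptions, not values from the repo's configs).
block = MBConvBlock(
    in_channels=32,
    out_channels=64,
    kernel_size=3,    # used as an int by _configure_padding and Depthwise
    stride=1,         # _convert_stride normalizes an int to a (1, 1) tuple
    bn_momentum=0.99,
    bn_eps=1e-3,
    se_ratio=0.25,    # 0 < se_ratio < 1, so the SqueezeAndExcite branch is built
    expand_ratio=4,   # != 1, so the InvertedBottleneck expansion is built
)

x = torch.randn(1, 32, 64, 64)
out = block(x)  # stochastic_dropout_rate defaults to None
print(out.shape)  # torch.Size([1, 64, 64, 64]), assuming padding preserves H and W

Since expand_ratio and se_ratio gate the optional sub-modules, passing expand_ratio=1 or se_ratio=None would skip the InvertedBottleneck and SqueezeAndExcite branches, respectively.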