diff options
Diffstat (limited to 'text_recognizer/networks/encoders/efficientnet/mbconv.py')
-rw-r--r-- | text_recognizer/networks/encoders/efficientnet/mbconv.py | 139 |
1 files changed, 56 insertions, 83 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/encoders/efficientnet/mbconv.py index 3aa63d0..e85df87 100644 --- a/text_recognizer/networks/encoders/efficientnet/mbconv.py +++ b/text_recognizer/networks/encoders/efficientnet/mbconv.py @@ -1,76 +1,62 @@ """Mobile inverted residual block.""" -from typing import Any, Optional, Union, Tuple +from typing import Optional, Sequence, Union, Tuple +import attr import torch from torch import nn, Tensor import torch.nn.functional as F -from .utils import stochastic_depth +from text_recognizer.networks.encoders.efficientnet.utils import stochastic_depth +def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]: + """Converts int to tuple.""" + return ( + (stride,) * 2 if isinstance(stride, int) else stride + ) + + +@attr.s(eq=False) class MBConvBlock(nn.Module): """Mobile Inverted Residual Bottleneck block.""" - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - stride: Union[Tuple[int, int], int], - bn_momentum: float, - bn_eps: float, - se_ratio: float, - expand_ratio: int, - *args: Any, - **kwargs: Any, - ) -> None: + def __attrs_pre_init__(self) -> None: super().__init__() - self.kernel_size = kernel_size - self.stride = (stride,) * 2 if isinstance(stride, int) else stride - self.bn_momentum = bn_momentum - self.bn_eps = bn_eps - self.in_channels = in_channels - self.out_channels = out_channels + in_channels: int = attr.ib() + out_channels: int = attr.ib() + kernel_size: Tuple[int, int] = attr.ib() + stride: Tuple[int, int] = attr.ib(converter=_convert_stride) + bn_momentum: float = attr.ib() + bn_eps: float = attr.ib() + se_ratio: float = attr.ib() + expand_ratio: int = attr.ib() + pad: Tuple[int, int, int, int] = attr.ib(init=False) + _inverted_bottleneck: nn.Sequential = attr.ib(init=False) + _depthwise: nn.Sequential = attr.ib(init=False) + _squeeze_excite: nn.Sequential = attr.ib(init=False) + _pointwise: nn.Sequential = attr.ib(init=False) + + @pad.default + def _configure_padding(self) -> Tuple[int, int, int, int]: + """Set padding for convolutional layers.""" if self.stride == (2, 2): - self.pad = [ + return ( (self.kernel_size - 1) // 2 - 1, (self.kernel_size - 1) // 2, - ] * 2 - else: - self.pad = [(self.kernel_size - 1) // 2] * 4 - - # Placeholders for layers. - self._inverted_bottleneck: nn.Sequential = None - self._depthwise: nn.Sequential = None - self._squeeze_excite: nn.Sequential = None - self._pointwise: nn.Sequential = None - - self._build( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - expand_ratio=expand_ratio, - se_ratio=se_ratio, - ) + ) * 2 + return ((self.kernel_size - 1) // 2,) * 4 + + def __attrs_post_init__(self) -> None: + """Post init configuration.""" + self._build() - def _build( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - stride: Union[Tuple[int, int], int], - expand_ratio: int, - se_ratio: float, - ) -> None: - has_se = se_ratio is not None and 0.0 < se_ratio < 1.0 - inner_channels = in_channels * expand_ratio + def _build(self) -> None: + has_se = self.se_ratio is not None and 0.0 < self.se_ratio < 1.0 + inner_channels = self.in_channels * self.expand_ratio self._inverted_bottleneck = ( - self._configure_inverted_bottleneck( - in_channels=in_channels, out_channels=inner_channels, - ) - if expand_ratio != 1 + self._configure_inverted_bottleneck(out_channels=inner_channels) + if self.expand_ratio != 1 else None ) @@ -78,31 +64,23 @@ class MBConvBlock(nn.Module): in_channels=inner_channels, out_channels=inner_channels, groups=inner_channels, - kernel_size=kernel_size, - stride=stride, ) self._squeeze_excite = ( self._configure_squeeze_excite( - in_channels=inner_channels, - out_channels=inner_channels, - se_ratio=se_ratio, + in_channels=inner_channels, out_channels=inner_channels, ) if has_se else None ) - self._pointwise = self._configure_pointwise( - in_channels=inner_channels, out_channels=out_channels - ) + self._pointwise = self._configure_pointwise(in_channels=inner_channels) - def _configure_inverted_bottleneck( - self, in_channels: int, out_channels: int, - ) -> nn.Sequential: + def _configure_inverted_bottleneck(self, out_channels: int) -> nn.Sequential: """Expansion phase.""" return nn.Sequential( nn.Conv2d( - in_channels=in_channels, + in_channels=self.in_channels, out_channels=out_channels, kernel_size=1, bias=False, @@ -114,19 +92,14 @@ class MBConvBlock(nn.Module): ) def _configure_depthwise( - self, - in_channels: int, - out_channels: int, - groups: int, - kernel_size: int, - stride: Union[Tuple[int, int], int], + self, in_channels: int, out_channels: int, groups: int, ) -> nn.Sequential: return nn.Sequential( nn.Conv2d( in_channels=in_channels, out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, + kernel_size=self.kernel_size, + stride=self.stride, groups=groups, bias=False, ), @@ -137,9 +110,9 @@ class MBConvBlock(nn.Module): ) def _configure_squeeze_excite( - self, in_channels: int, out_channels: int, se_ratio: float + self, in_channels: int, out_channels: int ) -> nn.Sequential: - num_squeezed_channels = max(1, int(in_channels * se_ratio)) + num_squeezed_channels = max(1, int(in_channels * self.se_ratio)) return nn.Sequential( nn.Conv2d( in_channels=in_channels, @@ -154,18 +127,18 @@ class MBConvBlock(nn.Module): ), ) - def _configure_pointwise( - self, in_channels: int, out_channels: int - ) -> nn.Sequential: + def _configure_pointwise(self, in_channels: int) -> nn.Sequential: return nn.Sequential( nn.Conv2d( in_channels=in_channels, - out_channels=out_channels, + out_channels=self.out_channels, kernel_size=1, bias=False, ), nn.BatchNorm2d( - num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps + num_features=self.out_channels, + momentum=self.bn_momentum, + eps=self.bn_eps, ), ) @@ -186,8 +159,8 @@ class MBConvBlock(nn.Module): residual = x if self._inverted_bottleneck is not None: x = self._inverted_bottleneck(x) - x = F.pad(x, self.pad) + x = F.pad(x, self.pad) x = self._depthwise(x) if self._squeeze_excite is not None: |