Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/encoders/efficientnet/mbconv_block.py | 62 |
1 files changed, 45 insertions, 17 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv_block.py b/text_recognizer/networks/encoders/efficientnet/mbconv_block.py
index 0384cd9..c501777 100644
--- a/text_recognizer/networks/encoders/efficientnet/mbconv_block.py
+++ b/text_recognizer/networks/encoders/efficientnet/mbconv_block.py
@@ -1,5 +1,5 @@
 """Mobile inverted residual block."""
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 from torch import nn, Tensor
@@ -14,6 +14,7 @@ class MBConvBlock(nn.Module):
     def __init__(
         self,
         in_channels: int,
+        out_channels: int,
         kernel_size: int,
         stride: int,
         bn_momentum: float,
@@ -28,12 +29,36 @@ class MBConvBlock(nn.Module):
         self.bn_momentum = bn_momentum
         self.bn_eps = bn_eps
         self.id_skip = id_skip
-        self.has_se = se_ratio is not None and 0.0 < se_ratio < 1.0
-
+        (
+            self._inverted_bottleneck,
+            self._depthwise,
+            self._squeeze_excite,
+            self._pointwise,
+        ) = self._build(
+            image_size=image_size,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            expand_ratio=expand_ratio,
+            se_ratio=se_ratio,
+        )
 
-    def _build(self, image_size: Tuple[int, int], in_channels: int, kernel_size: int, stride: int, expand_ratio: int) -> None:
+    def _build(
+        self,
+        image_size: Tuple[int, int],
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int,
+        expand_ratio: int,
+        se_ratio: float,
+    ) -> Tuple[
+        Optional[nn.Sequential], nn.Sequential, Optional[nn.Sequential], nn.Sequential
+    ]:
+        has_se = se_ratio is not None and 0.0 < se_ratio < 1.0
         inner_channels = in_channels * expand_ratio
-        self._inverted_bottleneck = (
+        inverted_bottleneck = (
             self._configure_inverted_bottleneck(
                 image_size=image_size,
                 in_channels=in_channels,
@@ -43,9 +68,9 @@ class MBConvBlock(nn.Module):
             else None
         )
 
-        self._depthwise = self._configure_depthwise(
+        depthwise = self._configure_depthwise(
             image_size=image_size,
-            in_channels=in_channels,
+            in_channels=inner_channels,
             out_channels=inner_channels,
             groups=inner_channels,
             kernel_size=kernel_size,
@@ -53,17 +78,20 @@ class MBConvBlock(nn.Module):
         )
         image_size = calculate_output_image_size(image_size, stride)
 
-        self._squeeze_excite = (
+        squeeze_excite = (
             self._configure_squeeze_excite(
-                in_channels=inner_channels, out_channels=inner_channels, se_ratio=se_ratio
+                in_channels=inner_channels,
+                out_channels=inner_channels,
+                se_ratio=se_ratio,
             )
-            if self.has_se
+            if has_se
             else None
         )
 
-        self._pointwise = self._configure_pointwise(
+        pointwise = self._configure_pointwise(
             image_size=image_size, in_channels=inner_channels, out_channels=out_channels
         )
+        return inverted_bottleneck, depthwise, squeeze_excite, pointwise
 
     def _configure_inverted_bottleneck(
         self,
@@ -83,7 +111,7 @@ class MBConvBlock(nn.Module):
             nn.BatchNorm2d(
                 num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
             ),
-            nn.SiLU(inplace=True),
+            nn.Mish(inplace=True),
         )
 
     def _configure_depthwise(
@@ -108,7 +136,7 @@ class MBConvBlock(nn.Module):
             nn.BatchNorm2d(
                 num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
             ),
-            nn.SiLU(inplace=True),
+            nn.Mish(inplace=True),
         )
 
     def _configure_squeeze_excite(
@@ -122,7 +150,7 @@ class MBConvBlock(nn.Module):
                 out_channels=num_squeezed_channels,
                 kernel_size=1,
             ),
-            nn.SiLU(inplace=True),
+            nn.Mish(inplace=True),
             Conv2d(
                 in_channels=num_squeezed_channels,
                 out_channels=out_channels,
@@ -156,8 +184,8 @@ class MBConvBlock(nn.Module):
         if self._squeeze_excite is not None:
            x_squeezed = F.adaptive_avg_pool2d(x, 1)
             x_squeezed = self._squeeze_excite(x)
-            x = torch.sigmoid(x_squeezed) * x
-
+            x = torch.tanh(F.softplus(x_squeezed)) * x
+
         x = self._pointwise(x)
-
+
         # Stochastic depth
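Note (not part of the commit): besides fixing the depthwise convolution to take inner_channels after the 1x1 expansion, this change swaps SiLU for Mish in the conv blocks and replaces the sigmoid gate on the squeeze-excite branch in forward() with tanh(softplus(.)), the gating term of Mish. A minimal sketch of that equivalence, assuming torch >= 1.9 (where nn.Mish is available):

# Sketch only (not from the diff): nn.Mish(x) equals x * tanh(softplus(x)),
# and the SE gate now uses the same tanh(softplus(.)) term instead of sigmoid.
import torch
import torch.nn.functional as F
from torch import nn

x = torch.randn(2, 8, 4, 4)

# Mish identity: Mish(x) = x * tanh(softplus(x))
assert torch.allclose(nn.Mish()(x), x * torch.tanh(F.softplus(x)), atol=1e-6)

# Squeeze-excite gating in forward(): old (sigmoid) vs. new (tanh(softplus)).
x_squeezed = torch.randn_like(x)
old_gate = torch.sigmoid(x_squeezed) * x            # before this commit
new_gate = torch.tanh(F.softplus(x_squeezed)) * x   # after this commit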