diff options
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/networks/encoders/efficientnet/mbconv.py | 30 |
1 files changed, 16 insertions, 14 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/encoders/efficientnet/mbconv.py index 7bfd9ba..f01c369 100644 --- a/text_recognizer/networks/encoders/efficientnet/mbconv.py +++ b/text_recognizer/networks/encoders/efficientnet/mbconv.py @@ -39,7 +39,10 @@ class MBConvBlock(nn.Module): def _configure_padding(self) -> Tuple[int, int, int, int]: """Set padding for convolutional layers.""" if self.stride == (2, 2): - return ((self.kernel_size - 1) // 2 - 1, (self.kernel_size - 1) // 2,) * 2 + return ( + (self.kernel_size - 1) // 2 - 1, + (self.kernel_size - 1) // 2, + ) * 2 return ((self.kernel_size - 1) // 2,) * 4 def __attrs_post_init__(self) -> None: @@ -56,14 +59,13 @@ class MBConvBlock(nn.Module): ) self._depthwise = self._configure_depthwise( - in_channels=inner_channels, - out_channels=inner_channels, + channels=inner_channels, groups=inner_channels, ) self._squeeze_excite = ( self._configure_squeeze_excite( - in_channels=inner_channels, out_channels=inner_channels, + channels=inner_channels, ) if has_se else None @@ -87,37 +89,37 @@ class MBConvBlock(nn.Module): ) def _configure_depthwise( - self, in_channels: int, out_channels: int, groups: int, + self, + channels: int, + groups: int, ) -> nn.Sequential: return nn.Sequential( nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, + in_channels=channels, + out_channels=channels, kernel_size=self.kernel_size, stride=self.stride, groups=groups, bias=False, ), nn.BatchNorm2d( - num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps + num_features=channels, momentum=self.bn_momentum, eps=self.bn_eps ), nn.Mish(inplace=True), ) - def _configure_squeeze_excite( - self, in_channels: int, out_channels: int - ) -> nn.Sequential: - num_squeezed_channels = max(1, int(in_channels * self.se_ratio)) + def _configure_squeeze_excite(self, channels: int) -> nn.Sequential: + num_squeezed_channels = max(1, int(self.in_channels * self.se_ratio)) return nn.Sequential( nn.Conv2d( - in_channels=in_channels, + in_channels=channels, out_channels=num_squeezed_channels, kernel_size=1, ), nn.Mish(inplace=True), nn.Conv2d( in_channels=num_squeezed_channels, - out_channels=out_channels, + out_channels=channels, kernel_size=1, ), ) |