summaryrefslogtreecommitdiff
path: root/text_recognizer/networks
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--text_recognizer/networks/encoders/efficientnet/mbconv.py30
1 files changed, 16 insertions, 14 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/encoders/efficientnet/mbconv.py
index 7bfd9ba..f01c369 100644
--- a/text_recognizer/networks/encoders/efficientnet/mbconv.py
+++ b/text_recognizer/networks/encoders/efficientnet/mbconv.py
@@ -39,7 +39,10 @@ class MBConvBlock(nn.Module):
def _configure_padding(self) -> Tuple[int, int, int, int]:
"""Set padding for convolutional layers."""
if self.stride == (2, 2):
- return ((self.kernel_size - 1) // 2 - 1, (self.kernel_size - 1) // 2,) * 2
+ return (
+ (self.kernel_size - 1) // 2 - 1,
+ (self.kernel_size - 1) // 2,
+ ) * 2
return ((self.kernel_size - 1) // 2,) * 4
def __attrs_post_init__(self) -> None:
@@ -56,14 +59,13 @@ class MBConvBlock(nn.Module):
)
self._depthwise = self._configure_depthwise(
- in_channels=inner_channels,
- out_channels=inner_channels,
+ channels=inner_channels,
groups=inner_channels,
)
self._squeeze_excite = (
self._configure_squeeze_excite(
- in_channels=inner_channels, out_channels=inner_channels,
+ channels=inner_channels,
)
if has_se
else None
@@ -87,37 +89,37 @@ class MBConvBlock(nn.Module):
)
def _configure_depthwise(
- self, in_channels: int, out_channels: int, groups: int,
+ self,
+ channels: int,
+ groups: int,
) -> nn.Sequential:
return nn.Sequential(
nn.Conv2d(
- in_channels=in_channels,
- out_channels=out_channels,
+ in_channels=channels,
+ out_channels=channels,
kernel_size=self.kernel_size,
stride=self.stride,
groups=groups,
bias=False,
),
nn.BatchNorm2d(
- num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
+ num_features=channels, momentum=self.bn_momentum, eps=self.bn_eps
),
nn.Mish(inplace=True),
)
- def _configure_squeeze_excite(
- self, in_channels: int, out_channels: int
- ) -> nn.Sequential:
- num_squeezed_channels = max(1, int(in_channels * self.se_ratio))
+ def _configure_squeeze_excite(self, channels: int) -> nn.Sequential:
+ num_squeezed_channels = max(1, int(self.in_channels * self.se_ratio))
return nn.Sequential(
nn.Conv2d(
- in_channels=in_channels,
+ in_channels=channels,
out_channels=num_squeezed_channels,
kernel_size=1,
),
nn.Mish(inplace=True),
nn.Conv2d(
in_channels=num_squeezed_channels,
- out_channels=out_channels,
+ out_channels=channels,
kernel_size=1,
),
)