From f8db5953a63c2467a76cc0610eb2ad0c96b69c70 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Sat, 6 Nov 2021 11:47:52 +0100
Subject: Fix efficientnet incorrect channel calculation

---
 .../networks/encoders/efficientnet/mbconv.py       | 30 ++++++++++++----------
 1 file changed, 16 insertions(+), 14 deletions(-)

(limited to 'text_recognizer/networks/encoders')

diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/encoders/efficientnet/mbconv.py
index 7bfd9ba..f01c369 100644
--- a/text_recognizer/networks/encoders/efficientnet/mbconv.py
+++ b/text_recognizer/networks/encoders/efficientnet/mbconv.py
@@ -39,7 +39,10 @@ class MBConvBlock(nn.Module):
     def _configure_padding(self) -> Tuple[int, int, int, int]:
         """Set padding for convolutional layers."""
         if self.stride == (2, 2):
-            return ((self.kernel_size - 1) // 2 - 1, (self.kernel_size - 1) // 2,) * 2
+            return (
+                (self.kernel_size - 1) // 2 - 1,
+                (self.kernel_size - 1) // 2,
+            ) * 2
         return ((self.kernel_size - 1) // 2,) * 4
 
     def __attrs_post_init__(self) -> None:
@@ -56,14 +59,13 @@ class MBConvBlock(nn.Module):
         )
 
         self._depthwise = self._configure_depthwise(
-            in_channels=inner_channels,
-            out_channels=inner_channels,
+            channels=inner_channels,
             groups=inner_channels,
         )
 
         self._squeeze_excite = (
             self._configure_squeeze_excite(
-                in_channels=inner_channels, out_channels=inner_channels,
+                channels=inner_channels,
             )
             if has_se
             else None
@@ -87,37 +89,37 @@ class MBConvBlock(nn.Module):
         )
 
     def _configure_depthwise(
-        self, in_channels: int, out_channels: int, groups: int,
+        self,
+        channels: int,
+        groups: int,
     ) -> nn.Sequential:
         return nn.Sequential(
             nn.Conv2d(
-                in_channels=in_channels,
-                out_channels=out_channels,
+                in_channels=channels,
+                out_channels=channels,
                 kernel_size=self.kernel_size,
                 stride=self.stride,
                 groups=groups,
                 bias=False,
             ),
             nn.BatchNorm2d(
-                num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
+                num_features=channels, momentum=self.bn_momentum, eps=self.bn_eps
             ),
             nn.Mish(inplace=True),
         )
 
-    def _configure_squeeze_excite(
-        self, in_channels: int, out_channels: int
-    ) -> nn.Sequential:
-        num_squeezed_channels = max(1, int(in_channels * self.se_ratio))
+    def _configure_squeeze_excite(self, channels: int) -> nn.Sequential:
+        num_squeezed_channels = max(1, int(self.in_channels * self.se_ratio))
         return nn.Sequential(
             nn.Conv2d(
-                in_channels=in_channels,
+                in_channels=channels,
                 out_channels=num_squeezed_channels,
                 kernel_size=1,
             ),
             nn.Mish(inplace=True),
             nn.Conv2d(
                 in_channels=num_squeezed_channels,
-                out_channels=out_channels,
+                out_channels=channels,
                 kernel_size=1,
             ),
         )
-- 
cgit v1.2.3-70-g09d2