From 14760428dd457f3749c6513ad34b822b05d6a742 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Fri, 10 Jun 2022 00:31:48 +0200
Subject: Fix efficientnet

---
 .../networks/efficientnet/efficientnet.py          |  5 ++--
 text_recognizer/networks/efficientnet/mbconv.py    | 28 ++++++++++++----------
 2 files changed, 18 insertions(+), 15 deletions(-)

diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py
index 2260ee2..2f2508d 100644
--- a/text_recognizer/networks/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/efficientnet/efficientnet.py
@@ -31,11 +31,11 @@ class EfficientNet(nn.Module):
     def __init__(
         self,
         arch: str,
-        params: Tuple[float, float, float],
         stochastic_dropout_rate: float = 0.2,
         bn_momentum: float = 0.99,
         bn_eps: float = 1.0e-3,
         depth: int = 7,
+        out_channels: int = 1280,
     ) -> None:
         super().__init__()
         self.params = self._get_arch_params(arch)
@@ -43,7 +43,7 @@ class EfficientNet(nn.Module):
         self.bn_momentum = bn_momentum
         self.bn_eps = bn_eps
         self.depth = depth
-        self.out_channels: int
+        self.out_channels: int = out_channels
         self._conv_stem: nn.Sequential
         self._blocks: nn.ModuleList
         self._conv_head: nn.Sequential
@@ -92,7 +92,6 @@ class EfficientNet(nn.Module):
                 args.stride = 1
 
         in_channels = round_filters(_block_args[-1].out_channels, self.params)
-        self.out_channels = round_filters(1280, self.params)
         self._conv_head = nn.Sequential(
             nn.Conv2d(
                 in_channels, self.out_channels, kernel_size=1, stride=1, bias=False
diff --git a/text_recognizer/networks/efficientnet/mbconv.py b/text_recognizer/networks/efficientnet/mbconv.py
index e090542..9c97925 100644
--- a/text_recognizer/networks/efficientnet/mbconv.py
+++ b/text_recognizer/networks/efficientnet/mbconv.py
@@ -11,12 +11,11 @@ from text_recognizer.networks.efficientnet.utils import stochastic_depth
 class BaseModule(nn.Module):
     """Base sub module class."""
 
-    def __init__(self, bn_momentum: float, bn_eps: float, block: nn.Sequential) -> None:
+    def __init__(self, bn_momentum: float, bn_eps: float) -> None:
         super().__init__()
 
         self.bn_momentum = bn_momentum
         self.bn_eps = bn_eps
-        self.block = block
         self._build()
 
     def _build(self) -> None:
@@ -34,13 +33,12 @@ class InvertedBottleneck(BaseModule):
         self,
         bn_momentum: float,
         bn_eps: float,
-        block: nn.Sequential,
         in_channels: int,
         out_channels: int,
     ) -> None:
-        super().__init__(bn_momentum, bn_eps, block)
         self.in_channels = in_channels
         self.out_channels = out_channels
+        super().__init__(bn_momentum, bn_eps)
 
     def _build(self) -> None:
         self.block = nn.Sequential(
@@ -66,15 +64,14 @@ class Depthwise(BaseModule):
         self,
         bn_momentum: float,
         bn_eps: float,
-        block: nn.Sequential,
         channels: int,
         kernel_size: int,
         stride: int,
     ) -> None:
-        super().__init__(bn_momentum, bn_eps, block)
         self.channels = channels
         self.kernel_size = kernel_size
         self.stride = stride
+        super().__init__(bn_momentum, bn_eps)
 
     def _build(self) -> None:
         self.block = nn.Sequential(
@@ -84,7 +81,6 @@ class Depthwise(BaseModule):
                 kernel_size=self.kernel_size,
                 stride=self.stride,
                 groups=self.channels,
-                padding="same",
                 bias=False,
             ),
             nn.BatchNorm2d(
@@ -101,16 +97,14 @@ class SqueezeAndExcite(BaseModule):
         self,
         bn_momentum: float,
         bn_eps: float,
-        block: nn.Sequential,
         in_channels: int,
         channels: int,
         se_ratio: float,
     ) -> None:
-        super().__init__(bn_momentum, bn_eps, block)
-
         self.in_channels = in_channels
         self.channels = channels
         self.se_ratio = se_ratio
+        super().__init__(bn_momentum, bn_eps)
 
     def _build(self) -> None:
         num_squeezed_channels = max(1, int(self.in_channels * self.se_ratio))
@@ -136,13 +130,12 @@ class Pointwise(BaseModule):
         self,
         bn_momentum: float,
         bn_eps: float,
-        block: nn.Sequential,
         in_channels: int,
         out_channels: int,
     ) -> None:
-        super().__init__(bn_momentum, bn_eps, block)
         self.in_channels = in_channels
         self.out_channels = out_channels
+        super().__init__(bn_momentum, bn_eps)
 
     def _build(self) -> None:
         self.block = nn.Sequential(
@@ -182,6 +175,7 @@ class MBConvBlock(nn.Module):
         self.bn_momentum = bn_momentum
         self.bn_eps = bn_eps
         self.se_ratio = se_ratio
+        self.pad = self._configure_padding()
         self.expand_ratio = expand_ratio
         self._inverted_bottleneck: Optional[InvertedBottleneck]
         self._depthwise: nn.Sequential
@@ -189,6 +183,15 @@ class MBConvBlock(nn.Module):
         self._pointwise: nn.Sequential
         self._build()
 
+    def _configure_padding(self) -> Tuple[int, int, int, int]:
+        """Set padding for convolutional layers."""
+        if self.stride == (2, 2):
+            return (
+                (self.kernel_size - 1) // 2 - 1,
+                (self.kernel_size - 1) // 2,
+            ) * 2
+        return ((self.kernel_size - 1) // 2,) * 4
+
     def _build(self) -> None:
         has_se = self.se_ratio is not None and 0.0 < self.se_ratio < 1.0
         inner_channels = self.in_channels * self.expand_ratio
@@ -249,6 +252,7 @@ class MBConvBlock(nn.Module):
         if self._inverted_bottleneck is not None:
             x = self._inverted_bottleneck(x)
 
+        x = F.pad(x, self.pad)
         x = self._depthwise(x)
 
         if self._squeeze_excite is not None:
-- 
cgit v1.2.3-70-g09d2