From e9753c0c1476d4c5aa614e1f65a8dd4302a1ce5b Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Fri, 25 Jun 2021 01:19:12 +0200
Subject: Efficientnet working

---
 .../networks/encoders/efficientnet/efficientnet.py |  68 ++++++-------
 .../networks/encoders/efficientnet/mbconv.py       |  64 ++++++------
 .../networks/encoders/efficientnet/utils.py        | 107 +--------------------
 3 files changed, 71 insertions(+), 168 deletions(-)

(limited to 'text_recognizer')

diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
index 98d58fd..b527d90 100644
--- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
@@ -1,13 +1,9 @@
 """Efficient net."""
-from typing import Tuple
-
 from torch import nn, Tensor
 
 from .mbconv import MBConvBlock
 from .utils import (
     block_args,
-    calculate_output_image_size,
-    get_same_padding_conv2d,
     round_filters,
     round_repeats,
 )
@@ -28,11 +24,19 @@ class EfficientNet(nn.Module):
         "l2": (4.3, 5.3, 0.5),
     }
 
-    def __init__(self, arch: str, image_size: Tuple[int, int]) -> None:
+    def __init__(
+        self,
+        arch: str,
+        stochastic_dropout_rate: float = 0.2,
+        bn_momentum: float = 0.99,
+        bn_eps: float = 1.0e-3,
+    ) -> None:
         super().__init__()
         assert arch in self.archs, f"{arch} not a valid efficient net architecure!"
         self.arch = self.archs[arch]
-        self.image_size = image_size
+        self.stochastic_dropout_rate = stochastic_dropout_rate
+        self.bn_momentum = 1 - bn_momentum
+        self.bn_eps = bn_eps
         self._conv_stem: nn.Sequential = None
         self._blocks: nn.Sequential = None
         self._conv_head: nn.Sequential = None
@@ -42,57 +46,53 @@ class EfficientNet(nn.Module):
         _block_args = block_args()
         in_channels = 1  # BW
         out_channels = round_filters(32, self.arch)
-        Conv2d = get_same_padding_conv2d(image_size=self.image_size)
         self._conv_stem = nn.Sequential(
-            Conv2d(
+            nn.Conv2d(
                 in_channels=in_channels,
                 out_channels=out_channels,
                 kernel_size=3,
-                stride=2,
+                stride=(2, 2),
                 bias=False,
             ),
-            nn.BatchNorm2d(num_features=out_channels, momentum=bn_momentum, eps=bn_eps),
+            nn.BatchNorm2d(
+                num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
+            ),
             nn.Mish(inplace=True),
         )
-        image_size = calculate_output_image_size(self.image_size, 2)
         self._blocks = nn.ModuleList([])
         for args in _block_args:
             args.in_channels = round_filters(args.in_channels, self.arch)
             args.out_channels = round_filters(args.out_channels, self.arch)
-            args.num_repeat = round_repeats(args.num_repeat, self.arch)
-
-            self._blocks.append(
-                MBConvBlock(
-                    **args,
-                    bn_momentum=bn_momentum,
-                    bn_eps=bn_eps,
-                    image_size=image_size,
-                )
-            )
-            image_size = calculate_output_image_size(image_size, args.stride)
-            if args.num_repeat > 1:
-                args.in_channels = args.out_channels
-                args.stride = 1
-            for _ in range(args.num_repeat - 1):
+            args.num_repeats = round_repeats(args.num_repeats, self.arch)
+            for _ in range(args.num_repeats):
                 self._blocks.append(
                     MBConvBlock(
                         **args,
-                        bn_momentum=bn_momentum,
-                        bn_eps=bn_eps,
-                        image_size=image_size,
+                        bn_momentum=self.bn_momentum,
+                        bn_eps=self.bn_eps,
                     )
                 )
+                args.in_channels = args.out_channels
+                args.stride = 1
 
-        in_channels = args.out_channels
+        in_channels = round_filters(320, self.arch)
         out_channels = round_filters(1280, self.arch)
-        Conv2d = get_same_padding_conv2d(image_size=image_size)
         self._conv_head = nn.Sequential(
-            Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
-            nn.BatchNorm2d(num_features=out_channels, momentum=bn_momentum, eps=bn_eps),
+            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
+            nn.BatchNorm2d(
+                num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
+            ),
         )
 
     def extract_features(self, x: Tensor) -> Tensor:
         x = self._conv_stem(x)
+        for i, block in enumerate(self._blocks):
+            stochastic_dropout_rate = self.stochastic_dropout_rate
+            if self.stochastic_dropout_rate:
+                stochastic_dropout_rate *= i / len(self._blocks)
+            x = block(x, stochastic_dropout_rate=stochastic_dropout_rate)
+        self._conv_head(x)
+        return x
 
     def forward(self, x: Tensor) -> Tensor:
-        pass
+        return self.extract_features(x)
diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/encoders/efficientnet/mbconv.py
index fbb3f22..e43771a 100644
--- a/text_recognizer/networks/encoders/efficientnet/mbconv.py
+++ b/text_recognizer/networks/encoders/efficientnet/mbconv.py
@@ -1,11 +1,11 @@
 """Mobile inverted residual block."""
-from typing import Any, Optional, Tuple
+from typing import Any, Optional, Union, Tuple
 
 import torch
 from torch import nn, Tensor
-from torch.nn import functional as F
+import torch.nn.functional as F
 
-from .utils import calculate_output_image_size, drop_connection, get_same_padding_conv2d
+from .utils import stochastic_depth
 
 
 class MBConvBlock(nn.Module):
@@ -16,22 +16,30 @@ class MBConvBlock(nn.Module):
         in_channels: int,
         out_channels: int,
         kernel_size: int,
-        stride: int,
+        stride: Union[Tuple[int, int], int],
         bn_momentum: float,
         bn_eps: float,
         se_ratio: float,
         expand_ratio: int,
-        image_size: Optional[Tuple[int, int]],
         *args: Any,
         **kwargs: Any,
     ) -> None:
         super().__init__()
         self.kernel_size = kernel_size
+        self.stride = (stride, ) * 2 if isinstance(stride, int) else stride
         self.bn_momentum = bn_momentum
         self.bn_eps = bn_eps
-        self.in_channels = self.in_channels
+        self.in_channels = in_channels
         self.out_channels = out_channels
 
+        if self.stride == (2, 2):
+            self.pad = [
+                (self.kernel_size - 1) // 2 - 1,
+                (self.kernel_size - 1) // 2,
+            ] * 2
+        else:
+            self.pad = [(self.kernel_size - 1) // 2] * 4
+
         # Placeholders for layers.
         self._inverted_bottleneck: nn.Sequential = None
         self._depthwise: nn.Sequential = None
@@ -39,7 +47,6 @@ class MBConvBlock(nn.Module):
         self._pointwise: nn.Sequential = None
 
         self._build(
-            image_size=image_size,
             in_channels=in_channels,
             out_channels=out_channels,
             kernel_size=kernel_size,
@@ -50,11 +57,10 @@ class MBConvBlock(nn.Module):
 
     def _build(
         self,
-        image_size: Optional[Tuple[int, int]],
         in_channels: int,
         out_channels: int,
         kernel_size: int,
-        stride: int,
+        stride: Union[Tuple[int, int], int],
         expand_ratio: int,
         se_ratio: float,
     ) -> None:
@@ -62,7 +68,6 @@ class MBConvBlock(nn.Module):
         inner_channels = in_channels * expand_ratio
         self._inverted_bottleneck = (
             self._configure_inverted_bottleneck(
-                image_size=image_size,
                 in_channels=in_channels,
                 out_channels=inner_channels,
             )
@@ -71,7 +76,6 @@ class MBConvBlock(nn.Module):
         )
 
         self._depthwise = self._configure_depthwise(
-            image_size=image_size,
             in_channels=inner_channels,
             out_channels=inner_channels,
             groups=inner_channels,
@@ -79,7 +83,6 @@ class MBConvBlock(nn.Module):
             stride=stride,
         )
 
-        image_size = calculate_output_image_size(image_size, stride)
         self._squeeze_excite = (
             self._configure_squeeze_excite(
                 in_channels=inner_channels,
@@ -91,19 +94,17 @@ class MBConvBlock(nn.Module):
         )
 
         self._pointwise = self._configure_pointwise(
-            image_size=image_size, in_channels=inner_channels, out_channels=out_channels
+            in_channels=inner_channels, out_channels=out_channels
         )
 
     def _configure_inverted_bottleneck(
         self,
-        image_size: Optional[Tuple[int, int]],
         in_channels: int,
         out_channels: int,
     ) -> nn.Sequential:
         """Expansion phase."""
-        Conv2d = get_same_padding_conv2d(image_size=image_size)
         return nn.Sequential(
-            Conv2d(
+            nn.Conv2d(
                 in_channels=in_channels,
                 out_channels=out_channels,
                 kernel_size=1,
@@ -117,16 +118,14 @@ class MBConvBlock(nn.Module):
 
     def _configure_depthwise(
         self,
-        image_size: Optional[Tuple[int, int]],
         in_channels: int,
         out_channels: int,
         groups: int,
         kernel_size: int,
-        stride: int,
+        stride: Union[Tuple[int, int], int],
     ) -> nn.Sequential:
-        Conv2d = get_same_padding_conv2d(image_size=image_size)
         return nn.Sequential(
-            Conv2d(
+            nn.Conv2d(
                 in_channels=in_channels,
                 out_channels=out_channels,
                 kernel_size=kernel_size,
@@ -143,16 +142,15 @@ class MBConvBlock(nn.Module):
     def _configure_squeeze_excite(
         self, in_channels: int, out_channels: int, se_ratio: float
     ) -> nn.Sequential:
-        Conv2d = get_same_padding_conv2d(image_size=(1, 1))
         num_squeezed_channels = max(1, int(in_channels * se_ratio))
         return nn.Sequential(
-            Conv2d(
+            nn.Conv2d(
                 in_channels=in_channels,
                 out_channels=num_squeezed_channels,
                 kernel_size=1,
             ),
             nn.Mish(inplace=True),
-            Conv2d(
+            nn.Conv2d(
                 in_channels=num_squeezed_channels,
                 out_channels=out_channels,
                 kernel_size=1,
@@ -160,11 +158,10 @@ class MBConvBlock(nn.Module):
         )
 
     def _configure_pointwise(
-        self, image_size: Optional[Tuple[int, int]], in_channels: int, out_channels: int
+        self, in_channels: int, out_channels: int
     ) -> nn.Sequential:
-        Conv2d = get_same_padding_conv2d(image_size=image_size)
         return nn.Sequential(
-            Conv2d(
+            nn.Conv2d(
                 in_channels=in_channels,
                 out_channels=out_channels,
                 kernel_size=1,
@@ -176,20 +173,23 @@ class MBConvBlock(nn.Module):
         )
 
     def _stochastic_depth(
-        self, x: Tensor, residual: Tensor, drop_connection_rate: Optional[float]
+        self, x: Tensor, residual: Tensor, stochastic_dropout_rate: Optional[float]
     ) -> Tensor:
-        if self.id_skip and self.stride == 1 and self.in_channels == self.out_channels:
-            if drop_connection_rate:
-                x = drop_connection(x, p=drop_connection_rate, training=self.training)
+        if self.stride == (1, 1) and self.in_channels == self.out_channels:
+            if stochastic_dropout_rate:
+                x = stochastic_depth(
+                    x, p=stochastic_dropout_rate, training=self.training
+                )
             x += residual
         return x
 
     def forward(
-        self, x: Tensor, drop_connection_rate: Optional[float] = None
+        self, x: Tensor, stochastic_dropout_rate: Optional[float] = None
     ) -> Tensor:
         residual = x
         if self._inverted_bottleneck is not None:
             x = self._inverted_bottleneck(x)
+        x = F.pad(x, self.pad)
 
         x = self._depthwise(x)
 
@@ -201,5 +201,5 @@ class MBConvBlock(nn.Module):
         x = self._pointwise(x)
 
         # Stochastic depth
-        x = self._stochastic_depth(x, residual, drop_connection_rate)
+        x = self._stochastic_depth(x, residual, stochastic_dropout_rate)
         return x
diff --git a/text_recognizer/networks/encoders/efficientnet/utils.py b/text_recognizer/networks/encoders/efficientnet/utils.py
index ff52485..6f293db 100644
--- a/text_recognizer/networks/encoders/efficientnet/utils.py
+++ b/text_recognizer/networks/encoders/efficientnet/utils.py
@@ -1,27 +1,15 @@
 """Util functions for efficient net."""
 from functools import partial
 import math
-from typing import Any, Optional, Tuple, Type
+from typing import Any, Optional, Union, Tuple, Type
 
 from omegaconf import OmegaConf
 import torch
-from torch import nn, Tensor
-import torch.functional as F
+from torch import Tensor
 
 
-def calculate_output_image_size(
-    image_size: Optional[Tuple[int, int]], stride: int
-) -> Optional[Tuple[int, int]]:
-    """Calculates the output image size when using conv2d with same padding."""
-    if image_size is None:
-        return None
-    height = int(math.ceil(image_size[0] / stride))
-    width = int(math.ceil(image_size[1] / stride))
-    return height, width
-
-
-def drop_connection(x: Tensor, p: float, training: bool) -> Tensor:
-    """Drop connection.
+def stochastic_depth(x: Tensor, p: float, training: bool) -> Tensor:
+    """Stochastic connection.
 
     Drops the entire convolution with a given survival probability.
 
@@ -57,91 +45,6 @@ def drop_connection(x: Tensor, p: float, training: bool) -> Tensor:
     return out
 
 
-def get_same_padding_conv2d(image_size: Optional[Tuple[int, int]]) -> Type[nn.Conv2d]:
-    if image_size is None:
-        return Conv2dDynamicSamePadding
-    return partial(Conv2dStaticSamePadding, image_size=image_size)
-
-
-class Conv2dDynamicSamePadding(nn.Conv2d):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        kernel_size: int,
-        stride: int = 1,
-        dilation: int = 1,
-        groups: int = 1,
-        bias: bool = True,
-    ) -> None:
-        super().__init__(
-            in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias
-        )
-        self.stride = [self.stride] * 2
-
-    def forward(self, x: Tensor) -> Tensor:
-        ih, iw = x.shape[-2:]
-        kh, kw = self.weight.shape[-2:]
-        sh, sw = self.stride
-        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
-        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
-        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
-        if pad_h > 0 or pad_w > 0:
-            x = F.pad(
-                x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
-            )
-        return F.conv2d(
-            x,
-            self.weight,
-            self.bias,
-            self.stride,
-            self.padding,
-            self.dilation,
-            self.groups,
-        )
-
-
-class Conv2dStaticSamePadding(nn.Conv2d):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        kernel_size: int,
-        image_size: Tuple[int, int],
-        stride: int = 1,
-        **kwargs: Any
-    ):
-        super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
-        self.stride = [self.stride] * 2
-
-        # Calculate padding based on image size and save it.
-        ih, iw = image_size
-        kh, kw = self.weight.shape[-2:]
-        sh, sw = self.stride
-        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
-        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
-        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
-        if pad_h > 0 or pad_w > 0:
-            self.static_padding = nn.ZeroPad2d(
-                (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
-            )
-        else:
-            self.static_padding = nn.Identity()
-
-    def forward(self, x: Tensor) -> Tensor:
-        x = self.static_padding(x)
-        x = F.pad(
-            x,
-            self.weight,
-            self.bias,
-            self.stride,
-            self.padding,
-            self.dilation,
-            self.groups,
-        )
-        return x
-
-
 def round_filters(filters: int, arch: Tuple[float, float, float]) -> int:
     multiplier = arch[0]
     divisor = 8
@@ -160,7 +63,7 @@ def block_args():
     keys = [
         "num_repeats",
         "kernel_size",
-        "strides",
+        "stride",
         "expand_ratio",
         "in_channels",
         "out_channels",
-- 
cgit v1.2.3-70-g09d2