From 9995922ff957ce424dca0655a01d8338a519aa86 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Sun, 20 Jun 2021 22:17:49 +0200
Subject: Working on new implementation of efficientnet

---
 text_recognizer/networks/__init__.py               |   2 +-
 text_recognizer/networks/encoders/__init__.py      |   2 +-
 text_recognizer/networks/encoders/efficientnet.py  | 145 ---------------
 .../networks/encoders/efficientnet/efficientnet.py |   9 +
 .../networks/encoders/efficientnet/mbconv.py       | 205 +++++++++++++++++++++
 .../networks/encoders/efficientnet/mbconv_block.py | 191 -------------------
 .../networks/encoders/efficientnet/utils.py        | 141 ++++++++++++++
 7 files changed, 357 insertions(+), 338 deletions(-)
 delete mode 100644 text_recognizer/networks/encoders/efficientnet.py
 create mode 100644 text_recognizer/networks/encoders/efficientnet/efficientnet.py
 create mode 100644 text_recognizer/networks/encoders/efficientnet/mbconv.py
 delete mode 100644 text_recognizer/networks/encoders/efficientnet/mbconv_block.py
 create mode 100644 text_recognizer/networks/encoders/efficientnet/utils.py

(limited to 'text_recognizer/networks')

diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py
index d1ebf1a..618450f 100644
--- a/text_recognizer/networks/__init__.py
+++ b/text_recognizer/networks/__init__.py
@@ -1,5 +1,5 @@
 """Network modules"""
-from .encoders import EfficientNet
+# from .encoders import EfficientNet
 from .vqvae import VQVAE
 
 # from .cnn_transformer import CNNTransformer
diff --git a/text_recognizer/networks/encoders/__init__.py b/text_recognizer/networks/encoders/__init__.py
index 25aed0e..b526b0c 100644
--- a/text_recognizer/networks/encoders/__init__.py
+++ b/text_recognizer/networks/encoders/__init__.py
@@ -1,2 +1,2 @@
 """Vision backbones."""
-from .efficientnet import EfficientNet
+# from .efficientnet import EfficientNet
diff --git a/text_recognizer/networks/encoders/efficientnet.py b/text_recognizer/networks/encoders/efficientnet.py
deleted file mode 100644
index 61dea77..0000000
--- a/text_recognizer/networks/encoders/efficientnet.py
+++ /dev/null
@@ -1,145 +0,0 @@
-"""Efficient net b0 implementation."""
-import torch
-from torch import nn
-from torch import Tensor
-
-
-class ConvNorm(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        kernel_size: int,
-        stride: int,
-        padding: int,
-        groups: int = 1,
-    ) -> None:
-        super().__init__()
-        self.block = nn.Sequential(
-            nn.Conv2d(
-                in_channels=in_channels,
-                out_channels=out_channels,
-                kernel_size=kernel_size,
-                stride=stride,
-                padding=padding,
-                groups=groups,
-                bias=False,
-            ),
-            nn.BatchNorm2d(num_features=out_channels),
-            nn.SiLU(inplace=True),
-        )
-
-    def forward(self, x: Tensor) -> Tensor:
-        return self.block(x)
-
-
-class SqueezeExcite(nn.Module):
-    def __init__(self, in_channels: int, reduce_dim: int) -> None:
-        super().__init__()
-        self.se = nn.Sequential(
-            nn.AdaptiveAvgPool2d(1),  # [C, H, W] -> [C, 1, 1]
-            nn.Conv2d(in_channels=in_channels, out_channels=reduce_dim, kernel_size=1),
-            nn.SiLU(),
-            nn.Conv2d(in_channels=reduce_dim, out_channels=in_channels, kernel_size=1),
-            nn.Sigmoid(),
-        )
-
-    def forward(self, x: Tensor) -> Tensor:
-        return x * self.se(x)
-
-
-class InvertedResidulaBlock(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        kernel_size: int,
-        stride: int,
-        padding: int,
-        expand_ratio: float,
-        reduction: int = 4,
-        survival_prob: float = 0.8,
-    ) -> None:
-        super().__init__()
-        self.survival_prob = survival_prob
-        self.use_residual = in_channels == out_channels and stride == 1
-        hidden_dim = in_channels * expand_ratio
-        self.expand = in_channels != hidden_dim
-        reduce_dim = in_channels // reduction
-
-        if self.expand:
-            self.expand_conv = ConvNorm(
-                in_channels, hidden_dim, kernel_size=3, stride=1, padding=1
-            )
-
-        self.conv = nn.Sequential(
-            ConvNorm(
-                hidden_dim, hidden_dim, kernel_size, stride, padding, groups=hidden_dim
-            ),
-            SqueezeExcite(hidden_dim, reduce_dim),
-            nn.Conv2d(
-                in_channels=hidden_dim,
-                out_channels=out_channels,
-                kernel_size=1,
-                bias=False,
-            ),
-            nn.BatchNorm2d(num_features=out_channels),
-        )
-
-    def stochastic_depth(self, x: Tensor) -> Tensor:
-        if not self.training:
-            return x
-
-        binary_tensor = (
-            torch.rand(x.shape[0], 1, 1, 1, device=x.device) < self.survival_prob
-        )
-        return torch.div(x, self.survival_prob) * binary_tensor
-
-    def forward(self, x: Tensor) -> Tensor:
-        out = self.expand_conv(x) if self.expand else x
-        if self.use_residual:
-            return self.stochastic_depth(self.conv(out)) + x
-        return self.conv(out)
-
-
-class EfficientNet(nn.Module):
-    """Efficient net b0 backbone."""
-
-    def __init__(self) -> None:
-        super().__init__()
-        self.base_model = [
-            # expand_ratio, channels, repeats, stride, kernel_size
-            [1, 16, 1, 1, 3],
-            [6, 24, 2, 2, 3],
-            [6, 40, 2, 2, 5],
-            [6, 80, 3, 2, 3],
-            [6, 112, 3, 1, 5],
-            [6, 192, 4, 2, 5],
-            [6, 320, 1, 1, 3],
-        ]
-
-        self.backbone = self._build_b0()
-
-    def _build_b0(self) -> nn.Sequential:
-        in_channels = 32
-        layers = [ConvNorm(1, in_channels, 3, stride=2, padding=1)]
-
-        for expand_ratio, out_channels, repeats, stride, kernel_size in self.base_model:
-            for i in range(repeats):
-                layers.append(
-                    InvertedResidulaBlock(
-                        in_channels,
-                        out_channels,
-                        expand_ratio=expand_ratio,
-                        stride=stride if i == 0 else 1,
-                        kernel_size=kernel_size,
-                        padding=kernel_size // 2,
-                    )
-                )
-                in_channels = out_channels
-        layers.append(ConvNorm(in_channels, 256, kernel_size=1, stride=1, padding=0))
-
-        return nn.Sequential(*layers)
-
-    def forward(self, x: Tensor) -> Tensor:
-        return self.backbone(x)
diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
new file mode 100644
index 0000000..d953c10
--- /dev/null
+++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
@@ -0,0 +1,9 @@
+"""Efficient net."""
+from torch import nn, Tensor
+
+
+class EfficientNet(nn.Module):
+    def __init__(
+        self,
+    ) -> None:
+        super().__init__()
diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/encoders/efficientnet/mbconv.py
new file mode 100644
index 0000000..602aeb7
--- /dev/null
+++ b/text_recognizer/networks/encoders/efficientnet/mbconv.py
@@ -0,0 +1,205 @@
+"""Mobile inverted residual block."""
+from typing import Optional, Tuple
+
+import torch
+from torch import nn, Tensor
+from torch.nn import functional as F
+
+from .utils import calculate_output_image_size, drop_connection, get_same_padding_conv2d
+
+
+class MBConvBlock(nn.Module):
+    """Mobile Inverted Residual Bottleneck block."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int,
+        bn_momentum: float,
+        bn_eps: float,
+        se_ratio: float,
+        id_skip: bool,
+        expand_ratio: int,
+        image_size: Optional[Tuple[int, int]],
+    ) -> None:
+        super().__init__()
+        self.kernel_size = kernel_size
+        self.bn_momentum = bn_momentum
+        self.bn_eps = bn_eps
+        self.id_skip = id_skip
+        self.in_channels = self.in_channels
+        self.out_channels = out_channels
+
+        # Placeholders for layers.
+        self._inverted_bottleneck: nn.Sequential = None
+        self._depthwise: nn.Sequential = None
+        self._squeeze_excite: nn.Sequential = None
+        self._pointwise: nn.Sequential = None
+
+        self._build(
+            image_size=image_size,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            expand_ratio=expand_ratio,
+            se_ratio=se_ratio,
+        )
+
+    def _build(
+        self,
+        image_size: Optional[Tuple[int, int]],
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int,
+        expand_ratio: int,
+        se_ratio: float,
+    ) -> None:
+        has_se = se_ratio is not None and 0.0 < se_ratio < 1.0
+        inner_channels = in_channels * expand_ratio
+        self._inverted_bottleneck = (
+            self._configure_inverted_bottleneck(
+                image_size=image_size,
+                in_channels=in_channels,
+                out_channels=inner_channels,
+            )
+            if expand_ratio != 1
+            else None
+        )
+
+        self._depthwise = self._configure_depthwise(
+            image_size=image_size,
+            in_channels=inner_channels,
+            out_channels=inner_channels,
+            groups=inner_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+        )
+
+        image_size = calculate_output_image_size(image_size, stride)
+        self._squeeze_excite = (
+            self._configure_squeeze_excite(
+                in_channels=inner_channels,
+                out_channels=inner_channels,
+                se_ratio=se_ratio,
+            )
+            if has_se
+            else None
+        )
+
+        self._pointwise = self._configure_pointwise(
+            image_size=image_size, in_channels=inner_channels, out_channels=out_channels
+        )
+
+    def _configure_inverted_bottleneck(
+        self,
+        image_size: Optional[Tuple[int, int]],
+        in_channels: int,
+        out_channels: int,
+    ) -> nn.Sequential:
+        """Expansion phase."""
+        Conv2d = get_same_padding_conv2d(image_size=image_size)
+        return nn.Sequential(
+            Conv2d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=1,
+                bias=False,
+            ),
+            nn.BatchNorm2d(
+                num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
+            ),
+            nn.Mish(inplace=True),
+        )
+
+    def _configure_depthwise(
+        self,
+        image_size: Optional[Tuple[int, int]],
+        in_channels: int,
+        out_channels: int,
+        groups: int,
+        kernel_size: int,
+        stride: int,
+    ) -> nn.Sequential:
+        Conv2d = get_same_padding_conv2d(image_size=image_size)
+        return nn.Sequential(
+            Conv2d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=kernel_size,
+                stride=stride,
+                groups=groups,
+                bias=False,
+            ),
+            nn.BatchNorm2d(
+                num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
+            ),
+            nn.Mish(inplace=True),
+        )
+
+    def _configure_squeeze_excite(
+        self, in_channels: int, out_channels: int, se_ratio: float
+    ) -> nn.Sequential:
+        Conv2d = get_same_padding_conv2d(image_size=(1, 1))
+        num_squeezed_channels = max(1, int(in_channels * se_ratio))
+        return nn.Sequential(
+            Conv2d(
+                in_channels=in_channels,
+                out_channels=num_squeezed_channels,
+                kernel_size=1,
+            ),
+            nn.Mish(inplace=True),
+            Conv2d(
+                in_channels=num_squeezed_channels,
+                out_channels=out_channels,
+                kernel_size=1,
+            ),
+        )
+
+    def _configure_pointwise(
+        self, image_size: Optional[Tuple[int, int]], in_channels: int, out_channels: int
+    ) -> nn.Sequential:
+        Conv2d = get_same_padding_conv2d(image_size=image_size)
+        return nn.Sequential(
+            Conv2d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=1,
+                bias=False,
+            ),
+            nn.BatchNorm2d(
+                num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
+            ),
+        )
+
+    def _stochastic_depth(
+        self, x: Tensor, residual: Tensor, drop_connection_rate: Optional[float]
+    ) -> Tensor:
+        if self.id_skip and self.stride == 1 and self.in_channels == self.out_channels:
+            if drop_connection_rate:
+                x = drop_connection(x, p=drop_connection_rate, training=self.training)
+            x += residual
+        return x
+
+    def forward(
+        self, x: Tensor, drop_connection_rate: Optional[float] = None
+    ) -> Tensor:
+        residual = x
+        if self._inverted_bottleneck is not None:
+            x = self._inverted_bottleneck(x)
+
+        x = self._depthwise(x)
+
+        if self._squeeze_excite is not None:
+            x_squeezed = F.adaptive_avg_pool2d(x, 1)
+            x_squeezed = self._squeeze_excite(x)
+            x = torch.tanh(F.softplus(x_squeezed)) * x
+
+        x = self._pointwise(x)
+
+        # Stochastic depth
+        x = self._stochastic_depth(x, residual, drop_connection_rate)
+        return x
diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv_block.py b/text_recognizer/networks/encoders/efficientnet/mbconv_block.py
deleted file mode 100644
index c501777..0000000
--- a/text_recognizer/networks/encoders/efficientnet/mbconv_block.py
+++ /dev/null
@@ -1,191 +0,0 @@
-"""Mobile inverted residual block."""
-from typing import Optional, Tuple
-
-import torch
-from torch import nn, Tensor
-from torch.nn import functional as F
-
-from .utils import get_same_padding_conv2d
-
-
-class MBConvBlock(nn.Module):
-    """Mobile Inverted Residual Bottleneck block."""
-
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        kernel_size: int,
-        stride: int,
-        bn_momentum: float,
-        bn_eps: float,
-        se_ratio: float,
-        id_skip: bool,
-        expand_ratio: int,
-        image_size: Tuple[int, int],
-    ) -> None:
-        super().__init__()
-        self.kernel_size = kernel_size
-        self.bn_momentum = bn_momentum
-        self.bn_eps = bn_eps
-        self.id_skip = id_skip
-        (
-            self._inverted_bottleneck,
-            self._depthwise,
-            self._squeeze_excite,
-            self._pointwise,
-        ) = self._build(
-            image_size=image_size,
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=kernel_size,
-            stride=stride,
-            expand_ratio=expand_ratio,
-            se_ratio=se_ratio,
-        )
-
-    def _build(
-        self,
-        image_size: Tuple[int, int],
-        in_channels: int,
-        out_channels: int,
-        kernel_size: int,
-        stride: int,
-        expand_ratio: int,
-        se_ratio: float,
-    ) -> Tuple[
-        Optional[nn.Sequential], nn.Sequential, Optional[nn.Sequential], nn.Sequential
-    ]:
-        has_se = se_ratio is not None and 0.0 < se_ratio < 1.0
-        inner_channels = in_channels * expand_ratio
-        inverted_bottleneck = (
-            self._configure_inverted_bottleneck(
-                image_size=image_size,
-                in_channels=in_channels,
-                out_channels=inner_channels,
-            )
-            if expand_ratio != 1
-            else None
-        )
-
-        depthwise = self._configure_depthwise(
-            image_size=image_size,
-            in_channels=inner_channels,
-            out_channels=inner_channels,
-            groups=inner_channels,
-            kernel_size=kernel_size,
-            stride=stride,
-        )
-
-        image_size = calculate_output_image_size(image_size, stride)
-        squeeze_excite = (
-            self._configure_squeeze_excite(
-                in_channels=inner_channels,
-                out_channels=inner_channels,
-                se_ratio=se_ratio,
-            )
-            if has_se
-            else None
-        )
-
-        pointwise = self._configure_pointwise(
-            image_size=image_size, in_channels=inner_channels, out_channels=out_channels
-        )
-        return inverted_bottleneck, depthwise, squeeze_excite, pointwise
-
-    def _configure_inverted_bottleneck(
-        self,
-        image_size: Tuple[int, int],
-        in_channels: int,
-        out_channels: int,
-    ) -> nn.Sequential:
-        """Expansion phase."""
-        Conv2d = get_same_padding_conv2d(image_size=image_size)
-        return nn.Sequential(
-            Conv2d(
-                in_channels=in_channels,
-                out_channels=out_channels,
-                kernel_size=1,
-                bias=False,
-            ),
-            nn.BatchNorm2d(
-                num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
-            ),
-            nn.Mish(inplace=True),
-        )
-
-    def _configure_depthwise(
-        self,
-        image_size: Tuple[int, int],
-        in_channels: int,
-        out_channels: int,
-        groups: int,
-        kernel_size: int,
-        stride: int,
-    ) -> nn.Sequential:
-        Conv2d = get_same_padding_conv2d(image_size=image_size)
-        return nn.Sequential(
-            Conv2d(
-                in_channels=in_channels,
-                out_channels=out_channels,
-                kernel_size=kernel_size,
-                stride=stride,
-                groups=groups,
-                bias=False,
-            ),
-            nn.BatchNorm2d(
-                num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
-            ),
-            nn.Mish(inplace=True),
-        )
-
-    def _configure_squeeze_excite(
-        self, in_channels: int, out_channels: int, se_ratio: float
-    ) -> nn.Sequential:
-        Conv2d = get_same_padding_conv2d(image_size=(1, 1))
-        num_squeezed_channels = max(1, int(in_channels * se_ratio))
-        return nn.Sequential(
-            Conv2d(
-                in_channels=in_channels,
-                out_channels=num_squeezed_channels,
-                kernel_size=1,
-            ),
-            nn.Mish(inplace=True),
-            Conv2d(
-                in_channels=num_squeezed_channels,
-                out_channels=out_channels,
-                kernel_size=1,
-            ),
-        )
-
-    def _configure_pointwise(
-        self, image_size: Tuple[int, int], in_channels: int, out_channels: int
-    ) -> nn.Sequential:
-        Conv2d = get_same_padding_conv2d(image_size=image_size)
-        return nn.Sequential(
-            Conv2d(
-                in_channels=in_channels,
-                out_channels=out_channels,
-                kernel_size=1,
-                bias=False,
-            ),
-            nn.BatchNorm2d(
-                num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
-            ),
-        )
-
-    def forward(self, x: Tensor, drop_connection_rate: Optional[float]) -> Tensor:
-        residual = x
-        if self._inverted_bottleneck is not None:
-            x = self._inverted_bottleneck(x)
-
-        x = self._depthwise(x)
-
-        if self._squeeze_excite is not None:
-            x_squeezed = F.adaptive_avg_pool2d(x, 1)
-            x_squeezed = self._squeeze_excite(x)
-            x = torch.tanh(F.softplus(x_squeezed)) * x
-
-        x = self._pointwise(x)
-
-        # Stochastic depth
diff --git a/text_recognizer/networks/encoders/efficientnet/utils.py b/text_recognizer/networks/encoders/efficientnet/utils.py
new file mode 100644
index 0000000..4b4a787
--- /dev/null
+++ b/text_recognizer/networks/encoders/efficientnet/utils.py
@@ -0,0 +1,141 @@
+"""Util functions for efficient net."""
+from functools import partial
+import math
+from typing import Any, Optional, Tuple, Type
+
+import torch
+from torch import nn, Tensor
+import torch.functional as F
+
+
+def calculate_output_image_size(
+    image_size: Optional[Tuple[int, int]], stride: int
+) -> Optional[Tuple[int, int]]:
+    """Calculates the output image size when using conv2d with same padding."""
+    if image_size is None:
+        return None
+    height = int(math.ceil(image_size[0] / stride))
+    width = int(math.ceil(image_size[1] / stride))
+    return height, width
+
+
+def drop_connection(x: Tensor, p: float, training: bool) -> Tensor:
+    """Drop connection.
+
+    Drops the entire convolution with a given survival probability.
+
+    Args:
+        x (Tensor): Input tensor.
+        p (float): Survival probability between 0.0 and 1.0.
+        training (bool): The running mode.
+
+    Shapes:
+        - x: :math: `(B, C, W, H)`.
+        - out: :math: `(B, C, W, H)`.
+
+        where B is the batch size, C is the number of channels, W is the width, and H
+        is the height.
+
+    Returns:
+        out (Tensor): Output after drop connection.
+    """
+    assert 0.0 <= p <= 1.0, "p must be in range of [0, 1]"
+
+    if not training:
+        return x
+
+    bsz = x.shape[0]
+    survival_prob = 1 - p
+
+    # Generate a binary tensor mask according to probability (p for 0, 1-p for 1)
+    random_tensor = survival_prob
+    random_tensor += torch.rand([bsz, 1, 1, 1]).type_as(x)
+    binary_tensor = torch.floor(random_tensor)
+
+    out = x / survival_prob * binary_tensor
+    return out
+
+
+def get_same_padding_conv2d(image_size: Optional[Tuple[int, int]]) -> Type[nn.Conv2d]:
+    if image_size is None:
+        return Conv2dDynamicSamePadding
+    return partial(Conv2dStaticSamePadding, image_size=image_size)
+
+
+class Conv2dDynamicSamePadding(nn.Conv2d):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+    ) -> None:
+        super().__init__(
+            in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias
+        )
+        self.stride = [self.stride] * 2
+
+    def forward(self, x: Tensor) -> Tensor:
+        ih, iw = x.shape[-2:]
+        kh, kw = self.weight.shape[-2:]
+        sh, sw = self.stride
+        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
+        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
+        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
+        if pad_h > 0 or pad_w > 0:
+            x = F.pad(
+                x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
+            )
+        return F.conv2d(
+            x,
+            self.weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+        )
+
+
+class Conv2dStaticSamePadding(nn.Conv2d):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        image_size: Tuple[int, int],
+        stride: int = 1,
+        **kwargs: Any
+    ):
+        super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
+        self.stride = [self.stride] * 2
+
+        # Calculate padding based on image size and save it.
+        ih, iw = image_size
+        kh, kw = self.weight.shape[-2:]
+        sh, sw = self.stride
+        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
+        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
+        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
+        if pad_h > 0 or pad_w > 0:
+            self.static_padding = nn.ZeroPad2d(
+                (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
+            )
+        else:
+            self.static_padding = nn.Identity()
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.static_padding(x)
+        x = F.pad(
+            x,
+            self.weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+        )
+        return x
-- 
cgit v1.2.3-70-g09d2