Diffstat (limited to 'text_recognizer/networks/efficientnet/mbconv.py')
-rw-r--r--  text_recognizer/networks/efficientnet/mbconv.py | 28
1 file changed, 16 insertions(+), 12 deletions(-)
diff --git a/text_recognizer/networks/efficientnet/mbconv.py b/text_recognizer/networks/efficientnet/mbconv.py
index e090542..9c97925 100644
--- a/text_recognizer/networks/efficientnet/mbconv.py
+++ b/text_recognizer/networks/efficientnet/mbconv.py
@@ -11,12 +11,11 @@ from text_recognizer.networks.efficientnet.utils import stochastic_depth
 class BaseModule(nn.Module):
     """Base sub module class."""

-    def __init__(self, bn_momentum: float, bn_eps: float, block: nn.Sequential) -> None:
+    def __init__(self, bn_momentum: float, bn_eps: float) -> None:
         super().__init__()
         self.bn_momentum = bn_momentum
         self.bn_eps = bn_eps
-        self.block = block

         self._build()

     def _build(self) -> None:
@@ -34,13 +33,12 @@ class InvertedBottleneck(BaseModule):
         self,
         bn_momentum: float,
         bn_eps: float,
-        block: nn.Sequential,
         in_channels: int,
         out_channels: int,
     ) -> None:
-        super().__init__(bn_momentum, bn_eps, block)
         self.in_channels = in_channels
         self.out_channels = out_channels
+        super().__init__(bn_momentum, bn_eps)

     def _build(self) -> None:
         self.block = nn.Sequential(
@@ -66,15 +64,14 @@ class Depthwise(BaseModule):
         self,
         bn_momentum: float,
         bn_eps: float,
-        block: nn.Sequential,
         channels: int,
         kernel_size: int,
         stride: int,
     ) -> None:
-        super().__init__(bn_momentum, bn_eps, block)
         self.channels = channels
         self.kernel_size = kernel_size
         self.stride = stride
+        super().__init__(bn_momentum, bn_eps)

     def _build(self) -> None:
         self.block = nn.Sequential(
@@ -84,7 +81,6 @@ class Depthwise(BaseModule):
                 kernel_size=self.kernel_size,
                 stride=self.stride,
                 groups=self.channels,
-                padding="same",
                 bias=False,
             ),
             nn.BatchNorm2d(
@@ -101,16 +97,14 @@ class SqueezeAndExcite(BaseModule):
         self,
         bn_momentum: float,
         bn_eps: float,
-        block: nn.Sequential,
         in_channels: int,
         channels: int,
         se_ratio: float,
     ) -> None:
-        super().__init__(bn_momentum, bn_eps, block)
-
         self.in_channels = in_channels
         self.channels = channels
         self.se_ratio = se_ratio
+        super().__init__(bn_momentum, bn_eps)

     def _build(self) -> None:
         num_squeezed_channels = max(1, int(self.in_channels * self.se_ratio))
@@ -136,13 +130,12 @@ class Pointwise(BaseModule):
         self,
         bn_momentum: float,
         bn_eps: float,
-        block: nn.Sequential,
         in_channels: int,
         out_channels: int,
     ) -> None:
-        super().__init__(bn_momentum, bn_eps, block)
         self.in_channels = in_channels
         self.out_channels = out_channels
+        super().__init__(bn_momentum, bn_eps)

     def _build(self) -> None:
         self.block = nn.Sequential(
@@ -182,6 +175,7 @@ class MBConvBlock(nn.Module):
         self.bn_momentum = bn_momentum
         self.bn_eps = bn_eps
         self.se_ratio = se_ratio
+        self.pad = self._configure_padding()
         self.expand_ratio = expand_ratio
         self._inverted_bottleneck: Optional[InvertedBottleneck]
         self._depthwise: nn.Sequential
@@ -189,6 +183,15 @@ class MBConvBlock(nn.Module):
         self._pointwise: nn.Sequential
         self._build()

+    def _configure_padding(self) -> Tuple[int, int, int, int]:
+        """Set padding for convolutional layers."""
+        if self.stride == (2, 2):
+            return (
+                (self.kernel_size - 1) // 2 - 1,
+                (self.kernel_size - 1) // 2,
+            ) * 2
+        return ((self.kernel_size - 1) // 2,) * 4
+
     def _build(self) -> None:
         has_se = self.se_ratio is not None and 0.0 < self.se_ratio < 1.0
         inner_channels = self.in_channels * self.expand_ratio
@@ -249,6 +252,7 @@ class MBConvBlock(nn.Module):
         if self._inverted_bottleneck is not None:
             x = self._inverted_bottleneck(x)

+        x = F.pad(x, self.pad)
         x = self._depthwise(x)

         if self._squeeze_excite is not None:
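
Note: a minimal sketch, not part of the commit, of what the new _configure_padding / F.pad pair computes, assuming a square input and a 3x3 depthwise kernel (the helper name same_padding is hypothetical). Stride-1 blocks keep the symmetric padding that padding="same" previously provided, while stride-(2, 2) blocks pad asymmetrically, TensorFlow-style, so the downsampled output size stays ceil(input / stride). PyTorch's Conv2d rejects padding="same" for strided convolutions, which is presumably why the padding moves out of the convolution and into an explicit F.pad call before the strided depthwise layer.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    def same_padding(kernel_size: int, stride: tuple) -> tuple:
        """Mirror the commit's _configure_padding: (left, right, top, bottom) for F.pad."""
        if stride == (2, 2):
            # Asymmetric padding, e.g. (0, 1, 0, 1) for a 3x3 kernel.
            return ((kernel_size - 1) // 2 - 1, (kernel_size - 1) // 2) * 2
        # Symmetric padding, e.g. (1, 1, 1, 1) for a 3x3 kernel.
        return ((kernel_size - 1) // 2,) * 4


    x = torch.randn(1, 32, 56, 56)
    x = F.pad(x, same_padding(kernel_size=3, stride=(2, 2)))  # 56x56 -> 57x57
    conv = nn.Conv2d(32, 32, kernel_size=3, stride=2, groups=32, bias=False)
    print(conv(x).shape)  # torch.Size([1, 32, 28, 28]) == ceil(56 / 2)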