Diffstat (limited to 'text_recognizer/networks/efficientnet/mbconv.py')
-rw-r--r-- | text_recognizer/networks/efficientnet/mbconv.py | 28 | ++++++++++++++++------------
1 file changed, 16 insertions, 12 deletions
diff --git a/text_recognizer/networks/efficientnet/mbconv.py b/text_recognizer/networks/efficientnet/mbconv.py
index e090542..9c97925 100644
--- a/text_recognizer/networks/efficientnet/mbconv.py
+++ b/text_recognizer/networks/efficientnet/mbconv.py
@@ -11,12 +11,11 @@ from text_recognizer.networks.efficientnet.utils import stochastic_depth
 class BaseModule(nn.Module):
     """Base sub module class."""
 
-    def __init__(self, bn_momentum: float, bn_eps: float, block: nn.Sequential) -> None:
+    def __init__(self, bn_momentum: float, bn_eps: float) -> None:
         super().__init__()
         self.bn_momentum = bn_momentum
         self.bn_eps = bn_eps
-        self.block = block
         self._build()
 
     def _build(self) -> None:
@@ -34,13 +33,12 @@ class InvertedBottleneck(BaseModule):
         self,
         bn_momentum: float,
         bn_eps: float,
-        block: nn.Sequential,
         in_channels: int,
         out_channels: int,
     ) -> None:
-        super().__init__(bn_momentum, bn_eps, block)
         self.in_channels = in_channels
         self.out_channels = out_channels
+        super().__init__(bn_momentum, bn_eps)
 
     def _build(self) -> None:
         self.block = nn.Sequential(
@@ -66,15 +64,14 @@ class Depthwise(BaseModule):
         self,
         bn_momentum: float,
         bn_eps: float,
-        block: nn.Sequential,
         channels: int,
         kernel_size: int,
         stride: int,
     ) -> None:
-        super().__init__(bn_momentum, bn_eps, block)
         self.channels = channels
         self.kernel_size = kernel_size
         self.stride = stride
+        super().__init__(bn_momentum, bn_eps)
 
     def _build(self) -> None:
         self.block = nn.Sequential(
@@ -84,7 +81,6 @@ class Depthwise(BaseModule):
                 kernel_size=self.kernel_size,
                 stride=self.stride,
                 groups=self.channels,
-                padding="same",
                 bias=False,
             ),
             nn.BatchNorm2d(
@@ -101,16 +97,14 @@ class SqueezeAndExcite(BaseModule):
         self,
         bn_momentum: float,
         bn_eps: float,
-        block: nn.Sequential,
         in_channels: int,
         channels: int,
         se_ratio: float,
     ) -> None:
-        super().__init__(bn_momentum, bn_eps, block)
-
         self.in_channels = in_channels
         self.channels = channels
         self.se_ratio = se_ratio
+        super().__init__(bn_momentum, bn_eps)
 
     def _build(self) -> None:
         num_squeezed_channels = max(1, int(self.in_channels * self.se_ratio))
@@ -136,13 +130,12 @@ class Pointwise(BaseModule):
         self,
         bn_momentum: float,
         bn_eps: float,
-        block: nn.Sequential,
         in_channels: int,
         out_channels: int,
     ) -> None:
-        super().__init__(bn_momentum, bn_eps, block)
         self.in_channels = in_channels
         self.out_channels = out_channels
+        super().__init__(bn_momentum, bn_eps)
 
     def _build(self) -> None:
         self.block = nn.Sequential(
@@ -182,6 +175,7 @@ class MBConvBlock(nn.Module):
         self.bn_momentum = bn_momentum
         self.bn_eps = bn_eps
         self.se_ratio = se_ratio
+        self.pad = self._configure_padding()
         self.expand_ratio = expand_ratio
         self._inverted_bottleneck: Optional[InvertedBottleneck]
         self._depthwise: nn.Sequential
@@ -189,6 +183,15 @@ class MBConvBlock(nn.Module):
         self._pointwise: nn.Sequential
         self._build()
 
+    def _configure_padding(self) -> Tuple[int, int, int, int]:
+        """Set padding for convolutional layers."""
+        if self.stride == (2, 2):
+            return (
+                (self.kernel_size - 1) // 2 - 1,
+                (self.kernel_size - 1) // 2,
+            ) * 2
+        return ((self.kernel_size - 1) // 2,) * 4
+
     def _build(self) -> None:
         has_se = self.se_ratio is not None and 0.0 < self.se_ratio < 1.0
         inner_channels = self.in_channels * self.expand_ratio
@@ -249,6 +252,7 @@ class MBConvBlock(nn.Module):
         if self._inverted_bottleneck is not None:
             x = self._inverted_bottleneck(x)
+        x = F.pad(x, self.pad)
 
         x = self._depthwise(x)
 
         if self._squeeze_excite is not None:
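Context on the padding change: PyTorch's Conv2d rejects padding="same" for strided convolutions, so the depthwise convolution drops that argument and MBConvBlock instead pads the input explicitly with F.pad before the depthwise stage. For the same reason the sub-modules now assign their attributes before delegating to BaseModule.__init__, since the base constructor immediately calls self._build(), which reads those attributes. The snippet below is a minimal, standalone sketch (not code from this repository) of the padding rule added in _configure_padding; the function name, channel count, and the example kernel/stride values are illustrative only.

import torch
import torch.nn.functional as F
from torch import nn


def configure_padding(kernel_size: int, stride: int) -> tuple:
    """Sketch of the padding rule from the diff: symmetric padding for
    stride 1, one pixel less on the left/top for stride 2."""
    if stride == 2:
        return ((kernel_size - 1) // 2 - 1, (kernel_size - 1) // 2) * 2
    return ((kernel_size - 1) // 2,) * 4


x = torch.randn(1, 32, 64, 64)  # NCHW; 32 channels is an arbitrary example
for kernel_size, stride in [(3, 1), (3, 2), (5, 2)]:
    pad = configure_padding(kernel_size, stride)
    conv = nn.Conv2d(32, 32, kernel_size, stride=stride, groups=32, bias=False)
    out = conv(F.pad(x, pad))  # pad order is (left, right, top, bottom)
    # For stride 1 this reproduces the removed padding="same" (64x64 out);
    # for stride 2 it halves even spatial sizes (32x32 out), which
    # Conv2d(padding="same") cannot express for strided convolutions.
    print(kernel_size, stride, tuple(out.shape))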