From 0738c29d88e78f8f464d5421e1f5f844ea54c2e7 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Thu, 9 Jun 2022 22:32:32 +0200 Subject: Refactor padding --- text_recognizer/networks/efficientnet/mbconv.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) (limited to 'text_recognizer/networks') diff --git a/text_recognizer/networks/efficientnet/mbconv.py b/text_recognizer/networks/efficientnet/mbconv.py index 64debd9..e090542 100644 --- a/text_recognizer/networks/efficientnet/mbconv.py +++ b/text_recognizer/networks/efficientnet/mbconv.py @@ -8,11 +8,6 @@ import torch.nn.functional as F from text_recognizer.networks.efficientnet.utils import stochastic_depth -def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]: - """Converts int to tuple.""" - return (stride,) * 2 if isinstance(stride, int) else stride - - class BaseModule(nn.Module): """Base sub module class.""" @@ -89,6 +84,7 @@ class Depthwise(BaseModule): kernel_size=self.kernel_size, stride=self.stride, groups=self.channels, + padding="same", bias=False, ), nn.BatchNorm2d( @@ -187,22 +183,12 @@ class MBConvBlock(nn.Module): self.bn_eps = bn_eps self.se_ratio = se_ratio self.expand_ratio = expand_ratio - self.pad = self._configure_padding() self._inverted_bottleneck: Optional[InvertedBottleneck] self._depthwise: nn.Sequential self._squeeze_excite: nn.Sequential self._pointwise: nn.Sequential self._build() - def _configure_padding(self) -> Tuple[int, int, int, int]: - """Set padding for convolutional layers.""" - if self.stride == (2, 2): - return ( - (self.kernel_size - 1) // 2 - 1, - (self.kernel_size - 1) // 2, - ) * 2 - return ((self.kernel_size - 1) // 2,) * 4 - def _build(self) -> None: has_se = self.se_ratio is not None and 0.0 < self.se_ratio < 1.0 inner_channels = self.in_channels * self.expand_ratio @@ -263,7 +249,6 @@ class MBConvBlock(nn.Module): if self._inverted_bottleneck is not None: x = self._inverted_bottleneck(x) - x = F.pad(x, self.pad) x = self._depthwise(x) if self._squeeze_excite is not None: -- cgit v1.2.3-70-g09d2