summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/efficientnet/mbconv.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-06-09 22:32:32 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-06-09 22:32:32 +0200
commit0738c29d88e78f8f464d5421e1f5f844ea54c2e7 (patch)
treed7ad478fb92830bcca127d51ea0a38c878aecd35 /text_recognizer/networks/efficientnet/mbconv.py
parent89e739b0a6818109b56e87a8403ec8abc32d4b7a (diff)
Refactor padding
Diffstat (limited to 'text_recognizer/networks/efficientnet/mbconv.py')
-rw-r--r--text_recognizer/networks/efficientnet/mbconv.py17
1 files changed, 1 insertions, 16 deletions
diff --git a/text_recognizer/networks/efficientnet/mbconv.py b/text_recognizer/networks/efficientnet/mbconv.py
index 64debd9..e090542 100644
--- a/text_recognizer/networks/efficientnet/mbconv.py
+++ b/text_recognizer/networks/efficientnet/mbconv.py
@@ -8,11 +8,6 @@ import torch.nn.functional as F
from text_recognizer.networks.efficientnet.utils import stochastic_depth
-def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]:
- """Converts int to tuple."""
- return (stride,) * 2 if isinstance(stride, int) else stride
-
-
class BaseModule(nn.Module):
"""Base sub module class."""
@@ -89,6 +84,7 @@ class Depthwise(BaseModule):
kernel_size=self.kernel_size,
stride=self.stride,
groups=self.channels,
+ padding="same",
bias=False,
),
nn.BatchNorm2d(
@@ -187,22 +183,12 @@ class MBConvBlock(nn.Module):
self.bn_eps = bn_eps
self.se_ratio = se_ratio
self.expand_ratio = expand_ratio
- self.pad = self._configure_padding()
self._inverted_bottleneck: Optional[InvertedBottleneck]
self._depthwise: nn.Sequential
self._squeeze_excite: nn.Sequential
self._pointwise: nn.Sequential
self._build()
- def _configure_padding(self) -> Tuple[int, int, int, int]:
- """Set padding for convolutional layers."""
- if self.stride == (2, 2):
- return (
- (self.kernel_size - 1) // 2 - 1,
- (self.kernel_size - 1) // 2,
- ) * 2
- return ((self.kernel_size - 1) // 2,) * 4
-
def _build(self) -> None:
has_se = self.se_ratio is not None and 0.0 < self.se_ratio < 1.0
inner_channels = self.in_channels * self.expand_ratio
@@ -263,7 +249,6 @@ class MBConvBlock(nn.Module):
if self._inverted_bottleneck is not None:
x = self._inverted_bottleneck(x)
- x = F.pad(x, self.pad)
x = self._depthwise(x)
if self._squeeze_excite is not None: