summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-06-16 22:05:48 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-06-16 22:05:48 +0200
commit668ac89736364f7960baf51ccf8d65d69d6bd71e (patch)
tree85a9567253a6b6735caaebd46937b8ce133acc68
parent85953dcbf4893653311d9a45b127d74e76af4ad3 (diff)
Updates to mbconvblock
-rw-r--r--text_recognizer/networks/encoders/efficientnet/mbconv_block.py62
1 files changed, 45 insertions, 17 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv_block.py b/text_recognizer/networks/encoders/efficientnet/mbconv_block.py
index 0384cd9..c501777 100644
--- a/text_recognizer/networks/encoders/efficientnet/mbconv_block.py
+++ b/text_recognizer/networks/encoders/efficientnet/mbconv_block.py
@@ -1,5 +1,5 @@
"""Mobile inverted residual block."""
-from typing import Tuple
+from typing import Optional, Tuple
import torch
from torch import nn, Tensor
@@ -14,6 +14,7 @@ class MBConvBlock(nn.Module):
def __init__(
self,
in_channels: int,
+ out_channels: int,
kernel_size: int,
stride: int,
bn_momentum: float,
@@ -28,12 +29,36 @@ class MBConvBlock(nn.Module):
self.bn_momentum = bn_momentum
self.bn_eps = bn_eps
self.id_skip = id_skip
- self.has_se = se_ratio is not None and 0.0 < se_ratio < 1.0
-
+ (
+ self._inverted_bottleneck,
+ self._depthwise,
+ self._squeeze_excite,
+ self._pointwise,
+ ) = self._build(
+ image_size=image_size,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ expand_ratio=expand_ratio,
+ se_ratio=se_ratio,
+ )
- def _build(self, image_size: Tuple[int, int], in_channels: int, kernel_size: int, stride: int, expand_ratio: int) -> None:
+ def _build(
+ self,
+ image_size: Tuple[int, int],
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int,
+ expand_ratio: int,
+ se_ratio: float,
+ ) -> Tuple[
+ Optional[nn.Sequential], nn.Sequential, Optional[nn.Sequential], nn.Sequential
+ ]:
+ has_se = se_ratio is not None and 0.0 < se_ratio < 1.0
inner_channels = in_channels * expand_ratio
- self._inverted_bottleneck = (
+ inverted_bottleneck = (
self._configure_inverted_bottleneck(
image_size=image_size,
in_channels=in_channels,
@@ -43,9 +68,9 @@ class MBConvBlock(nn.Module):
else None
)
- self._depthwise = self._configure_depthwise(
+ depthwise = self._configure_depthwise(
image_size=image_size,
- in_channels=in_channels,
+ in_channels=inner_channels,
out_channels=inner_channels,
groups=inner_channels,
kernel_size=kernel_size,
@@ -53,17 +78,20 @@ class MBConvBlock(nn.Module):
)
image_size = calculate_output_image_size(image_size, stride)
- self._squeeze_excite = (
+ squeeze_excite = (
self._configure_squeeze_excite(
- in_channels=inner_channels, out_channels=inner_channels, se_ratio=se_ratio
+ in_channels=inner_channels,
+ out_channels=inner_channels,
+ se_ratio=se_ratio,
)
- if self.has_se
+ if has_se
else None
)
- self._pointwise = self._configure_pointwise(
+ pointwise = self._configure_pointwise(
image_size=image_size, in_channels=inner_channels, out_channels=out_channels
)
+ return inverted_bottleneck, depthwise, squeeze_excite, pointwise
def _configure_inverted_bottleneck(
self,
@@ -83,7 +111,7 @@ class MBConvBlock(nn.Module):
nn.BatchNorm2d(
num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
),
- nn.SiLU(inplace=True),
+ nn.Mish(inplace=True),
)
def _configure_depthwise(
@@ -108,7 +136,7 @@ class MBConvBlock(nn.Module):
nn.BatchNorm2d(
num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
),
- nn.SiLU(inplace=True),
+ nn.Mish(inplace=True),
)
def _configure_squeeze_excite(
@@ -122,7 +150,7 @@ class MBConvBlock(nn.Module):
out_channels=num_squeezed_channels,
kernel_size=1,
),
- nn.SiLU(inplace=True),
+ nn.Mish(inplace=True),
Conv2d(
in_channels=num_squeezed_channels,
out_channels=out_channels,
@@ -156,8 +184,8 @@ class MBConvBlock(nn.Module):
if self._squeeze_excite is not None:
x_squeezed = F.adaptive_avg_pool2d(x, 1)
x_squeezed = self._squeeze_excite(x)
- x = torch.sigmoid(x_squeezed) * x
-
+ x = torch.tanh(F.softplus(x_squeezed)) * x
+
x = self._pointwise(x)
-
+
# Stochastic depth