summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/encoders/efficientnet/mbconv.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/encoders/efficientnet/mbconv.py')
-rw-r--r--text_recognizer/networks/encoders/efficientnet/mbconv.py64
1 files changed, 32 insertions, 32 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/encoders/efficientnet/mbconv.py
index fbb3f22..e43771a 100644
--- a/text_recognizer/networks/encoders/efficientnet/mbconv.py
+++ b/text_recognizer/networks/encoders/efficientnet/mbconv.py
@@ -1,11 +1,11 @@
"""Mobile inverted residual block."""
-from typing import Any, Optional, Tuple
+from typing import Any, Optional, Union, Tuple
import torch
from torch import nn, Tensor
-from torch.nn import functional as F
+import torch.nn.functional as F
-from .utils import calculate_output_image_size, drop_connection, get_same_padding_conv2d
+from .utils import stochastic_depth
class MBConvBlock(nn.Module):
@@ -16,22 +16,30 @@ class MBConvBlock(nn.Module):
in_channels: int,
out_channels: int,
kernel_size: int,
- stride: int,
+ stride: Union[Tuple[int, int], int],
bn_momentum: float,
bn_eps: float,
se_ratio: float,
expand_ratio: int,
- image_size: Optional[Tuple[int, int]],
*args: Any,
**kwargs: Any,
) -> None:
super().__init__()
self.kernel_size = kernel_size
+ self.stride = (stride, ) * 2 if isinstance(stride, int) else stride
self.bn_momentum = bn_momentum
self.bn_eps = bn_eps
- self.in_channels = self.in_channels
+ self.in_channels = in_channels
self.out_channels = out_channels
+ if self.stride == (2, 2):
+ self.pad = [
+ (self.kernel_size - 1) // 2 - 1,
+ (self.kernel_size - 1) // 2,
+ ] * 2
+ else:
+ self.pad = [(self.kernel_size - 1) // 2] * 4
+
# Placeholders for layers.
self._inverted_bottleneck: nn.Sequential = None
self._depthwise: nn.Sequential = None
@@ -39,7 +47,6 @@ class MBConvBlock(nn.Module):
self._pointwise: nn.Sequential = None
self._build(
- image_size=image_size,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
@@ -50,11 +57,10 @@ class MBConvBlock(nn.Module):
def _build(
self,
- image_size: Optional[Tuple[int, int]],
in_channels: int,
out_channels: int,
kernel_size: int,
- stride: int,
+ stride: Union[Tuple[int, int], int],
expand_ratio: int,
se_ratio: float,
) -> None:
@@ -62,7 +68,6 @@ class MBConvBlock(nn.Module):
inner_channels = in_channels * expand_ratio
self._inverted_bottleneck = (
self._configure_inverted_bottleneck(
- image_size=image_size,
in_channels=in_channels,
out_channels=inner_channels,
)
@@ -71,7 +76,6 @@ class MBConvBlock(nn.Module):
)
self._depthwise = self._configure_depthwise(
- image_size=image_size,
in_channels=inner_channels,
out_channels=inner_channels,
groups=inner_channels,
@@ -79,7 +83,6 @@ class MBConvBlock(nn.Module):
stride=stride,
)
- image_size = calculate_output_image_size(image_size, stride)
self._squeeze_excite = (
self._configure_squeeze_excite(
in_channels=inner_channels,
@@ -91,19 +94,17 @@ class MBConvBlock(nn.Module):
)
self._pointwise = self._configure_pointwise(
- image_size=image_size, in_channels=inner_channels, out_channels=out_channels
+ in_channels=inner_channels, out_channels=out_channels
)
def _configure_inverted_bottleneck(
self,
- image_size: Optional[Tuple[int, int]],
in_channels: int,
out_channels: int,
) -> nn.Sequential:
"""Expansion phase."""
- Conv2d = get_same_padding_conv2d(image_size=image_size)
return nn.Sequential(
- Conv2d(
+ nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
@@ -117,16 +118,14 @@ class MBConvBlock(nn.Module):
def _configure_depthwise(
self,
- image_size: Optional[Tuple[int, int]],
in_channels: int,
out_channels: int,
groups: int,
kernel_size: int,
- stride: int,
+ stride: Union[Tuple[int, int], int],
) -> nn.Sequential:
- Conv2d = get_same_padding_conv2d(image_size=image_size)
return nn.Sequential(
- Conv2d(
+ nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
@@ -143,16 +142,15 @@ class MBConvBlock(nn.Module):
def _configure_squeeze_excite(
self, in_channels: int, out_channels: int, se_ratio: float
) -> nn.Sequential:
- Conv2d = get_same_padding_conv2d(image_size=(1, 1))
num_squeezed_channels = max(1, int(in_channels * se_ratio))
return nn.Sequential(
- Conv2d(
+ nn.Conv2d(
in_channels=in_channels,
out_channels=num_squeezed_channels,
kernel_size=1,
),
nn.Mish(inplace=True),
- Conv2d(
+ nn.Conv2d(
in_channels=num_squeezed_channels,
out_channels=out_channels,
kernel_size=1,
@@ -160,11 +158,10 @@ class MBConvBlock(nn.Module):
)
def _configure_pointwise(
- self, image_size: Optional[Tuple[int, int]], in_channels: int, out_channels: int
+ self, in_channels: int, out_channels: int
) -> nn.Sequential:
- Conv2d = get_same_padding_conv2d(image_size=image_size)
return nn.Sequential(
- Conv2d(
+ nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
@@ -176,20 +173,23 @@ class MBConvBlock(nn.Module):
)
def _stochastic_depth(
- self, x: Tensor, residual: Tensor, drop_connection_rate: Optional[float]
+ self, x: Tensor, residual: Tensor, stochastic_dropout_rate: Optional[float]
) -> Tensor:
- if self.id_skip and self.stride == 1 and self.in_channels == self.out_channels:
- if drop_connection_rate:
- x = drop_connection(x, p=drop_connection_rate, training=self.training)
+ if self.stride == (1, 1) and self.in_channels == self.out_channels:
+ if stochastic_dropout_rate:
+ x = stochastic_depth(
+ x, p=stochastic_dropout_rate, training=self.training
+ )
x += residual
return x
def forward(
- self, x: Tensor, drop_connection_rate: Optional[float] = None
+ self, x: Tensor, stochastic_dropout_rate: Optional[float] = None
) -> Tensor:
residual = x
if self._inverted_bottleneck is not None:
x = self._inverted_bottleneck(x)
+ x = F.pad(x, self.pad)
x = self._depthwise(x)
@@ -201,5 +201,5 @@ class MBConvBlock(nn.Module):
x = self._pointwise(x)
# Stochastic depth
- x = self._stochastic_depth(x, residual, drop_connection_rate)
+ x = self._stochastic_depth(x, residual, stochastic_dropout_rate)
return x