summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/encoders
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-07 23:38:48 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-07 23:38:48 +0100
commit5b7d6eec4160f7b3ea0862915b9428b37a3d49fa (patch)
treebbcaf670def50f81bb2729b543d9a098cdce1f09 /text_recognizer/networks/encoders
parentf8db5953a63c2467a76cc0610eb2ad0c96b69c70 (diff)
Fix mbconv sub modules
Diffstat (limited to 'text_recognizer/networks/encoders')
-rw-r--r--text_recognizer/networks/encoders/efficientnet/mbconv.py207
1 files changed, 139 insertions, 68 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/encoders/efficientnet/mbconv.py
index f01c369..96626fc 100644
--- a/text_recognizer/networks/encoders/efficientnet/mbconv.py
+++ b/text_recognizer/networks/encoders/efficientnet/mbconv.py
@@ -1,5 +1,5 @@
"""Mobile inverted residual block."""
-from typing import Optional, Sequence, Union, Tuple
+from typing import Optional, Tuple, Union
import attr
import torch
@@ -15,119 +15,112 @@ def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]:
@attr.s(eq=False)
-class MBConvBlock(nn.Module):
- """Mobile Inverted Residual Bottleneck block."""
-
- def __attrs_pre_init__(self) -> None:
- super().__init__()
+class BaseModule(nn.Module):
+ """Base sub module class."""
- in_channels: int = attr.ib()
- out_channels: int = attr.ib()
- kernel_size: Tuple[int, int] = attr.ib()
- stride: Tuple[int, int] = attr.ib(converter=_convert_stride)
bn_momentum: float = attr.ib()
bn_eps: float = attr.ib()
- se_ratio: float = attr.ib()
- expand_ratio: int = attr.ib()
- pad: Tuple[int, int, int, int] = attr.ib(init=False)
- _inverted_bottleneck: nn.Sequential = attr.ib(init=False)
- _depthwise: nn.Sequential = attr.ib(init=False)
- _squeeze_excite: nn.Sequential = attr.ib(init=False)
- _pointwise: nn.Sequential = attr.ib(init=False)
+ block: nn.Sequential = attr.ib(init=False)
- @pad.default
- def _configure_padding(self) -> Tuple[int, int, int, int]:
- """Set padding for convolutional layers."""
- if self.stride == (2, 2):
- return (
- (self.kernel_size - 1) // 2 - 1,
- (self.kernel_size - 1) // 2,
- ) * 2
- return ((self.kernel_size - 1) // 2,) * 4
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
def __attrs_post_init__(self) -> None:
- """Post init configuration."""
self._build()
def _build(self) -> None:
- has_se = self.se_ratio is not None and 0.0 < self.se_ratio < 1.0
- inner_channels = self.in_channels * self.expand_ratio
- self._inverted_bottleneck = (
- self._configure_inverted_bottleneck(out_channels=inner_channels)
- if self.expand_ratio != 1
- else None
- )
+ pass
- self._depthwise = self._configure_depthwise(
- channels=inner_channels,
- groups=inner_channels,
- )
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ return self.block(x)
- self._squeeze_excite = (
- self._configure_squeeze_excite(
- channels=inner_channels,
- )
- if has_se
- else None
- )
- self._pointwise = self._configure_pointwise(in_channels=inner_channels)
+@attr.s(auto_attribs=True, eq=False)
+class InvertedBottleneck(BaseModule):
+ """Inverted bottleneck module."""
+
+ in_channels: int = attr.ib()
+ out_channels: int = attr.ib()
- def _configure_inverted_bottleneck(self, out_channels: int) -> nn.Sequential:
- """Expansion phase."""
- return nn.Sequential(
+ def _build(self) -> None:
+ self.block = nn.Sequential(
nn.Conv2d(
in_channels=self.in_channels,
- out_channels=out_channels,
+ out_channels=self.out_channels,
kernel_size=1,
bias=False,
),
nn.BatchNorm2d(
- num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
+ num_features=self.out_channels,
+ momentum=self.bn_momentum,
+ eps=self.bn_eps,
),
nn.Mish(inplace=True),
)
- def _configure_depthwise(
- self,
- channels: int,
- groups: int,
- ) -> nn.Sequential:
- return nn.Sequential(
+
+@attr.s(auto_attribs=True, eq=False)
+class Depthwise(BaseModule):
+ """Depthwise convolution module."""
+
+ channels: int = attr.ib()
+ kernel_size: int = attr.ib()
+ stride: int = attr.ib()
+
+ def _build(self) -> None:
+ self.block = nn.Sequential(
nn.Conv2d(
- in_channels=channels,
- out_channels=channels,
+ in_channels=self.channels,
+ out_channels=self.channels,
kernel_size=self.kernel_size,
stride=self.stride,
- groups=groups,
+ groups=self.channels,
bias=False,
),
nn.BatchNorm2d(
- num_features=channels, momentum=self.bn_momentum, eps=self.bn_eps
+ num_features=self.channels, momentum=self.bn_momentum, eps=self.bn_eps
),
nn.Mish(inplace=True),
)
- def _configure_squeeze_excite(self, channels: int) -> nn.Sequential:
+
+@attr.s(auto_attribs=True, eq=False)
+class SqueezeAndExcite(BaseModule):
+ """Sequeeze and excite module."""
+
+ in_channels: int = attr.ib()
+ channels: int = attr.ib()
+ se_ratio: float = attr.ib()
+
+ def _build(self) -> None:
num_squeezed_channels = max(1, int(self.in_channels * self.se_ratio))
- return nn.Sequential(
+ self.block = nn.Sequential(
nn.Conv2d(
- in_channels=channels,
+ in_channels=self.channels,
out_channels=num_squeezed_channels,
kernel_size=1,
),
nn.Mish(inplace=True),
nn.Conv2d(
in_channels=num_squeezed_channels,
- out_channels=channels,
+ out_channels=self.channels,
kernel_size=1,
),
)
- def _configure_pointwise(self, in_channels: int) -> nn.Sequential:
- return nn.Sequential(
+
+@attr.s(auto_attribs=True, eq=False)
+class Pointwise(BaseModule):
+ """Pointwise module."""
+
+ in_channels: int = attr.ib()
+ out_channels: int = attr.ib()
+
+ def _build(self) -> None:
+ self.block = nn.Sequential(
nn.Conv2d(
- in_channels=in_channels,
+ in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=1,
bias=False,
@@ -139,6 +132,83 @@ class MBConvBlock(nn.Module):
),
)
+
+@attr.s(eq=False)
+class MBConvBlock(nn.Module):
+ """Mobile Inverted Residual Bottleneck block."""
+
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
+
+ in_channels: int = attr.ib()
+ out_channels: int = attr.ib()
+ kernel_size: Tuple[int, int] = attr.ib()
+ stride: Tuple[int, int] = attr.ib(converter=_convert_stride)
+ bn_momentum: float = attr.ib()
+ bn_eps: float = attr.ib()
+ se_ratio: float = attr.ib()
+ expand_ratio: int = attr.ib()
+ pad: Tuple[int, int, int, int] = attr.ib(init=False)
+ _inverted_bottleneck: Optional[InvertedBottleneck] = attr.ib(init=False)
+ _depthwise: nn.Sequential = attr.ib(init=False)
+ _squeeze_excite: nn.Sequential = attr.ib(init=False)
+ _pointwise: nn.Sequential = attr.ib(init=False)
+
+ @pad.default
+ def _configure_padding(self) -> Tuple[int, int, int, int]:
+ """Set padding for convolutional layers."""
+ if self.stride == (2, 2):
+ return (
+ (self.kernel_size - 1) // 2 - 1,
+ (self.kernel_size - 1) // 2,
+ ) * 2
+ return ((self.kernel_size - 1) // 2,) * 4
+
+ def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
+ self._build()
+
+ def _build(self) -> None:
+ has_se = self.se_ratio is not None and 0.0 < self.se_ratio < 1.0
+ inner_channels = self.in_channels * self.expand_ratio
+ self._inverted_bottleneck = (
+ InvertedBottleneck(
+ in_channels=self.in_channels,
+ out_channels=inner_channels,
+ bn_momentum=self.bn_momentum,
+ bn_eps=self.bn_eps,
+ )
+ if self.expand_ratio != 1
+ else None
+ )
+
+ self._depthwise = Depthwise(
+ channels=inner_channels,
+ kernel_size=self.kernel_size,
+ stride=self.stride,
+ bn_momentum=self.bn_momentum,
+ bn_eps=self.bn_eps,
+ )
+
+ self._squeeze_excite = (
+ SqueezeAndExcite(
+ in_channels=self.in_channels,
+ channels=inner_channels,
+ se_ratio=self.se_ratio,
+ bn_momentum=self.bn_momentum,
+ bn_eps=self.bn_eps,
+ )
+ if has_se
+ else None
+ )
+
+ self._pointwise = Pointwise(
+ in_channels=inner_channels,
+ out_channels=self.out_channels,
+ bn_momentum=self.bn_momentum,
+ bn_eps=self.bn_eps,
+ )
+
def _stochastic_depth(
self, x: Tensor, residual: Tensor, stochastic_dropout_rate: Optional[float]
) -> Tensor:
@@ -153,6 +223,7 @@ class MBConvBlock(nn.Module):
def forward(
self, x: Tensor, stochastic_dropout_rate: Optional[float] = None
) -> Tensor:
+ """Forward pass."""
residual = x
if self._inverted_bottleneck is not None:
x = self._inverted_bottleneck(x)