summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/encoders/efficientnet/mbconv.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-02 21:13:48 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-02 21:13:48 +0200
commit75801019981492eedf9280cb352eea3d8e99b65f (patch)
tree6521cc4134459e42591b2375f70acd348741474e /text_recognizer/networks/encoders/efficientnet/mbconv.py
parente5eca28438cd17d436359f2c6eee0bb9e55d2a8b (diff)
Fix log import, fix mapping in datamodules, fix nn modules can be hashed
Diffstat (limited to 'text_recognizer/networks/encoders/efficientnet/mbconv.py')
-rw-r--r--text_recognizer/networks/encoders/efficientnet/mbconv.py139
1 files changed, 56 insertions, 83 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/encoders/efficientnet/mbconv.py
index 3aa63d0..e85df87 100644
--- a/text_recognizer/networks/encoders/efficientnet/mbconv.py
+++ b/text_recognizer/networks/encoders/efficientnet/mbconv.py
@@ -1,76 +1,62 @@
"""Mobile inverted residual block."""
-from typing import Any, Optional, Union, Tuple
+from typing import Optional, Sequence, Union, Tuple
+import attr
import torch
from torch import nn, Tensor
import torch.nn.functional as F
-from .utils import stochastic_depth
+from text_recognizer.networks.encoders.efficientnet.utils import stochastic_depth
+def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]:
+ """Converts int to tuple."""
+ return (
+ (stride,) * 2 if isinstance(stride, int) else stride
+ )
+
+
+@attr.s(eq=False)
class MBConvBlock(nn.Module):
"""Mobile Inverted Residual Bottleneck block."""
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- kernel_size: int,
- stride: Union[Tuple[int, int], int],
- bn_momentum: float,
- bn_eps: float,
- se_ratio: float,
- expand_ratio: int,
- *args: Any,
- **kwargs: Any,
- ) -> None:
+ def __attrs_pre_init__(self) -> None:
super().__init__()
- self.kernel_size = kernel_size
- self.stride = (stride,) * 2 if isinstance(stride, int) else stride
- self.bn_momentum = bn_momentum
- self.bn_eps = bn_eps
- self.in_channels = in_channels
- self.out_channels = out_channels
+ in_channels: int = attr.ib()
+ out_channels: int = attr.ib()
+ kernel_size: Tuple[int, int] = attr.ib()
+ stride: Tuple[int, int] = attr.ib(converter=_convert_stride)
+ bn_momentum: float = attr.ib()
+ bn_eps: float = attr.ib()
+ se_ratio: float = attr.ib()
+ expand_ratio: int = attr.ib()
+ pad: Tuple[int, int, int, int] = attr.ib(init=False)
+ _inverted_bottleneck: nn.Sequential = attr.ib(init=False)
+ _depthwise: nn.Sequential = attr.ib(init=False)
+ _squeeze_excite: nn.Sequential = attr.ib(init=False)
+ _pointwise: nn.Sequential = attr.ib(init=False)
+
+ @pad.default
+ def _configure_padding(self) -> Tuple[int, int, int, int]:
+ """Set padding for convolutional layers."""
if self.stride == (2, 2):
- self.pad = [
+ return (
(self.kernel_size - 1) // 2 - 1,
(self.kernel_size - 1) // 2,
- ] * 2
- else:
- self.pad = [(self.kernel_size - 1) // 2] * 4
-
- # Placeholders for layers.
- self._inverted_bottleneck: nn.Sequential = None
- self._depthwise: nn.Sequential = None
- self._squeeze_excite: nn.Sequential = None
- self._pointwise: nn.Sequential = None
-
- self._build(
- in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=kernel_size,
- stride=stride,
- expand_ratio=expand_ratio,
- se_ratio=se_ratio,
- )
+ ) * 2
+ return ((self.kernel_size - 1) // 2,) * 4
+
+ def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
+ self._build()
- def _build(
- self,
- in_channels: int,
- out_channels: int,
- kernel_size: int,
- stride: Union[Tuple[int, int], int],
- expand_ratio: int,
- se_ratio: float,
- ) -> None:
- has_se = se_ratio is not None and 0.0 < se_ratio < 1.0
- inner_channels = in_channels * expand_ratio
+ def _build(self) -> None:
+ has_se = self.se_ratio is not None and 0.0 < self.se_ratio < 1.0
+ inner_channels = self.in_channels * self.expand_ratio
self._inverted_bottleneck = (
- self._configure_inverted_bottleneck(
- in_channels=in_channels, out_channels=inner_channels,
- )
- if expand_ratio != 1
+ self._configure_inverted_bottleneck(out_channels=inner_channels)
+ if self.expand_ratio != 1
else None
)
@@ -78,31 +64,23 @@ class MBConvBlock(nn.Module):
in_channels=inner_channels,
out_channels=inner_channels,
groups=inner_channels,
- kernel_size=kernel_size,
- stride=stride,
)
self._squeeze_excite = (
self._configure_squeeze_excite(
- in_channels=inner_channels,
- out_channels=inner_channels,
- se_ratio=se_ratio,
+ in_channels=inner_channels, out_channels=inner_channels,
)
if has_se
else None
)
- self._pointwise = self._configure_pointwise(
- in_channels=inner_channels, out_channels=out_channels
- )
+ self._pointwise = self._configure_pointwise(in_channels=inner_channels)
- def _configure_inverted_bottleneck(
- self, in_channels: int, out_channels: int,
- ) -> nn.Sequential:
+ def _configure_inverted_bottleneck(self, out_channels: int) -> nn.Sequential:
"""Expansion phase."""
return nn.Sequential(
nn.Conv2d(
- in_channels=in_channels,
+ in_channels=self.in_channels,
out_channels=out_channels,
kernel_size=1,
bias=False,
@@ -114,19 +92,14 @@ class MBConvBlock(nn.Module):
)
def _configure_depthwise(
- self,
- in_channels: int,
- out_channels: int,
- groups: int,
- kernel_size: int,
- stride: Union[Tuple[int, int], int],
+ self, in_channels: int, out_channels: int, groups: int,
) -> nn.Sequential:
return nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
- kernel_size=kernel_size,
- stride=stride,
+ kernel_size=self.kernel_size,
+ stride=self.stride,
groups=groups,
bias=False,
),
@@ -137,9 +110,9 @@ class MBConvBlock(nn.Module):
)
def _configure_squeeze_excite(
- self, in_channels: int, out_channels: int, se_ratio: float
+ self, in_channels: int, out_channels: int
) -> nn.Sequential:
- num_squeezed_channels = max(1, int(in_channels * se_ratio))
+ num_squeezed_channels = max(1, int(in_channels * self.se_ratio))
return nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
@@ -154,18 +127,18 @@ class MBConvBlock(nn.Module):
),
)
- def _configure_pointwise(
- self, in_channels: int, out_channels: int
- ) -> nn.Sequential:
+ def _configure_pointwise(self, in_channels: int) -> nn.Sequential:
return nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
- out_channels=out_channels,
+ out_channels=self.out_channels,
kernel_size=1,
bias=False,
),
nn.BatchNorm2d(
- num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
+ num_features=self.out_channels,
+ momentum=self.bn_momentum,
+ eps=self.bn_eps,
),
)
@@ -186,8 +159,8 @@ class MBConvBlock(nn.Module):
residual = x
if self._inverted_bottleneck is not None:
x = self._inverted_bottleneck(x)
- x = F.pad(x, self.pad)
+ x = F.pad(x, self.pad)
x = self._depthwise(x)
if self._squeeze_excite is not None: