diff options
Diffstat (limited to 'text_recognizer/networks/efficientnet/mbconv.py')
-rw-r--r-- | text_recognizer/networks/efficientnet/mbconv.py | 121 |
1 files changed, 78 insertions, 43 deletions
diff --git a/text_recognizer/networks/efficientnet/mbconv.py b/text_recognizer/networks/efficientnet/mbconv.py index 98e9353..64debd9 100644 --- a/text_recognizer/networks/efficientnet/mbconv.py +++ b/text_recognizer/networks/efficientnet/mbconv.py @@ -1,7 +1,6 @@ """Mobile inverted residual block.""" from typing import Optional, Tuple, Union -from attrs import define, field import torch from torch import nn, Tensor import torch.nn.functional as F @@ -14,18 +13,15 @@ def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]: return (stride,) * 2 if isinstance(stride, int) else stride -@define(eq=False) class BaseModule(nn.Module): """Base sub module class.""" - bn_momentum: float = field() - bn_eps: float = field() - block: nn.Sequential = field(init=False) - - def __attrs_pre_init__(self) -> None: + def __init__(self, bn_momentum: float, bn_eps: float, block: nn.Sequential) -> None: super().__init__() - def __attrs_post_init__(self) -> None: + self.bn_momentum = bn_momentum + self.bn_eps = bn_eps + self.block = block self._build() def _build(self) -> None: @@ -36,12 +32,20 @@ class BaseModule(nn.Module): return self.block(x) -@define(auto_attribs=True, eq=False) class InvertedBottleneck(BaseModule): """Inverted bottleneck module.""" - in_channels: int = field() - out_channels: int = field() + def __init__( + self, + bn_momentum: float, + bn_eps: float, + block: nn.Sequential, + in_channels: int, + out_channels: int, + ) -> None: + super().__init__(bn_momentum, bn_eps, block) + self.in_channels = in_channels + self.out_channels = out_channels def _build(self) -> None: self.block = nn.Sequential( @@ -60,13 +64,22 @@ class InvertedBottleneck(BaseModule): ) -@define(auto_attribs=True, eq=False) class Depthwise(BaseModule): """Depthwise convolution module.""" - channels: int = field() - kernel_size: int = field() - stride: int = field() + def __init__( + self, + bn_momentum: float, + bn_eps: float, + block: nn.Sequential, + channels: int, + kernel_size: int, + stride: int, + ) -> None: + super().__init__(bn_momentum, bn_eps, block) + self.channels = channels + self.kernel_size = kernel_size + self.stride = stride def _build(self) -> None: self.block = nn.Sequential( @@ -85,13 +98,23 @@ class Depthwise(BaseModule): ) -@define(auto_attribs=True, eq=False) class SqueezeAndExcite(BaseModule): """Sequeeze and excite module.""" - in_channels: int = field() - channels: int = field() - se_ratio: float = field() + def __init__( + self, + bn_momentum: float, + bn_eps: float, + block: nn.Sequential, + in_channels: int, + channels: int, + se_ratio: float, + ) -> None: + super().__init__(bn_momentum, bn_eps, block) + + self.in_channels = in_channels + self.channels = channels + self.se_ratio = se_ratio def _build(self) -> None: num_squeezed_channels = max(1, int(self.in_channels * self.se_ratio)) @@ -110,12 +133,20 @@ class SqueezeAndExcite(BaseModule): ) -@define(auto_attribs=True, eq=False) class Pointwise(BaseModule): """Pointwise module.""" - in_channels: int = field() - out_channels: int = field() + def __init__( + self, + bn_momentum: float, + bn_eps: float, + block: nn.Sequential, + in_channels: int, + out_channels: int, + ) -> None: + super().__init__(bn_momentum, bn_eps, block) + self.in_channels = in_channels + self.out_channels = out_channels def _build(self) -> None: self.block = nn.Sequential( @@ -133,28 +164,36 @@ class Pointwise(BaseModule): ) -@define(eq=False) class MBConvBlock(nn.Module): """Mobile Inverted Residual Bottleneck block.""" - def __attrs_pre_init__(self) -> None: + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Tuple[int, int], + stride: Tuple[int, int], + bn_momentum: float, + bn_eps: float, + se_ratio: float, + expand_ratio: int, + ) -> None: super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.bn_momentum = bn_momentum + self.bn_eps = bn_eps + self.se_ratio = se_ratio + self.expand_ratio = expand_ratio + self.pad = self._configure_padding() + self._inverted_bottleneck: Optional[InvertedBottleneck] + self._depthwise: nn.Sequential + self._squeeze_excite: nn.Sequential + self._pointwise: nn.Sequential + self._build() - in_channels: int = field() - out_channels: int = field() - kernel_size: Tuple[int, int] = field() - stride: Tuple[int, int] = field(converter=_convert_stride) - bn_momentum: float = field() - bn_eps: float = field() - se_ratio: float = field() - expand_ratio: int = field() - pad: Tuple[int, int, int, int] = field(init=False) - _inverted_bottleneck: Optional[InvertedBottleneck] = field(init=False) - _depthwise: nn.Sequential = field(init=False) - _squeeze_excite: nn.Sequential = field(init=False) - _pointwise: nn.Sequential = field(init=False) - - @pad.default def _configure_padding(self) -> Tuple[int, int, int, int]: """Set padding for convolutional layers.""" if self.stride == (2, 2): @@ -164,10 +203,6 @@ class MBConvBlock(nn.Module): ) * 2 return ((self.kernel_size - 1) // 2,) * 4 - def __attrs_post_init__(self) -> None: - """Post init configuration.""" - self._build() - def _build(self) -> None: has_se = self.se_ratio is not None and 0.0 < self.se_ratio < 1.0 inner_channels = self.in_channels * self.expand_ratio |