diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-01 23:10:12 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-01 23:10:12 +0200 |
commit | db86cef2d308f58325278061c6aa177a535e7e03 (patch) | |
tree | a013fa85816337269f9cdc5a8992813fa62d299d /text_recognizer/networks/efficientnet/mbconv.py | |
parent | b980a281712a5b1ee7ee5bd8f5d4762cd91a070b (diff) |
Replace attr with attrs
Diffstat (limited to 'text_recognizer/networks/efficientnet/mbconv.py')
-rw-r--r-- | text_recognizer/networks/efficientnet/mbconv.py | 71 |
1 files changed, 37 insertions, 34 deletions
diff --git a/text_recognizer/networks/efficientnet/mbconv.py b/text_recognizer/networks/efficientnet/mbconv.py index beb7d57..98e9353 100644 --- a/text_recognizer/networks/efficientnet/mbconv.py +++ b/text_recognizer/networks/efficientnet/mbconv.py @@ -1,7 +1,7 @@ """Mobile inverted residual block.""" from typing import Optional, Tuple, Union -import attr +from attrs import define, field import torch from torch import nn, Tensor import torch.nn.functional as F @@ -14,13 +14,13 @@ def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]: return (stride,) * 2 if isinstance(stride, int) else stride -@attr.s(eq=False) +@define(eq=False) class BaseModule(nn.Module): """Base sub module class.""" - bn_momentum: float = attr.ib() - bn_eps: float = attr.ib() - block: nn.Sequential = attr.ib(init=False) + bn_momentum: float = field() + bn_eps: float = field() + block: nn.Sequential = field(init=False) def __attrs_pre_init__(self) -> None: super().__init__() @@ -36,12 +36,12 @@ class BaseModule(nn.Module): return self.block(x) -@attr.s(auto_attribs=True, eq=False) +@define(auto_attribs=True, eq=False) class InvertedBottleneck(BaseModule): """Inverted bottleneck module.""" - in_channels: int = attr.ib() - out_channels: int = attr.ib() + in_channels: int = field() + out_channels: int = field() def _build(self) -> None: self.block = nn.Sequential( @@ -60,13 +60,13 @@ class InvertedBottleneck(BaseModule): ) -@attr.s(auto_attribs=True, eq=False) +@define(auto_attribs=True, eq=False) class Depthwise(BaseModule): """Depthwise convolution module.""" - channels: int = attr.ib() - kernel_size: int = attr.ib() - stride: int = attr.ib() + channels: int = field() + kernel_size: int = field() + stride: int = field() def _build(self) -> None: self.block = nn.Sequential( @@ -85,13 +85,13 @@ class Depthwise(BaseModule): ) -@attr.s(auto_attribs=True, eq=False) +@define(auto_attribs=True, eq=False) class SqueezeAndExcite(BaseModule): """Sequeeze and excite module.""" - in_channels: int = attr.ib() - channels: int = attr.ib() - se_ratio: float = attr.ib() + in_channels: int = field() + channels: int = field() + se_ratio: float = field() def _build(self) -> None: num_squeezed_channels = max(1, int(self.in_channels * self.se_ratio)) @@ -110,12 +110,12 @@ class SqueezeAndExcite(BaseModule): ) -@attr.s(auto_attribs=True, eq=False) +@define(auto_attribs=True, eq=False) class Pointwise(BaseModule): """Pointwise module.""" - in_channels: int = attr.ib() - out_channels: int = attr.ib() + in_channels: int = field() + out_channels: int = field() def _build(self) -> None: self.block = nn.Sequential( @@ -133,32 +133,35 @@ class Pointwise(BaseModule): ) -@attr.s(eq=False) +@define(eq=False) class MBConvBlock(nn.Module): """Mobile Inverted Residual Bottleneck block.""" def __attrs_pre_init__(self) -> None: super().__init__() - in_channels: int = attr.ib() - out_channels: int = attr.ib() - kernel_size: Tuple[int, int] = attr.ib() - stride: Tuple[int, int] = attr.ib(converter=_convert_stride) - bn_momentum: float = attr.ib() - bn_eps: float = attr.ib() - se_ratio: float = attr.ib() - expand_ratio: int = attr.ib() - pad: Tuple[int, int, int, int] = attr.ib(init=False) - _inverted_bottleneck: Optional[InvertedBottleneck] = attr.ib(init=False) - _depthwise: nn.Sequential = attr.ib(init=False) - _squeeze_excite: nn.Sequential = attr.ib(init=False) - _pointwise: nn.Sequential = attr.ib(init=False) + in_channels: int = field() + out_channels: int = field() + kernel_size: Tuple[int, int] = field() + stride: Tuple[int, int] = field(converter=_convert_stride) + bn_momentum: float = field() + bn_eps: float = field() + se_ratio: float = field() + expand_ratio: int = field() + pad: Tuple[int, int, int, int] = field(init=False) + _inverted_bottleneck: Optional[InvertedBottleneck] = field(init=False) + _depthwise: nn.Sequential = field(init=False) + _squeeze_excite: nn.Sequential = field(init=False) + _pointwise: nn.Sequential = field(init=False) @pad.default def _configure_padding(self) -> Tuple[int, int, int, int]: """Set padding for convolutional layers.""" if self.stride == (2, 2): - return ((self.kernel_size - 1) // 2 - 1, (self.kernel_size - 1) // 2,) * 2 + return ( + (self.kernel_size - 1) // 2 - 1, + (self.kernel_size - 1) // 2, + ) * 2 return ((self.kernel_size - 1) // 2,) * 4 def __attrs_post_init__(self) -> None: |