From 65df6a72b002c4b23d6f2eb545839e157f7f2aa0 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 5 Jun 2022 23:39:11 +0200 Subject: Remove attrs --- .../networks/efficientnet/efficientnet.py | 47 ++++---- text_recognizer/networks/efficientnet/mbconv.py | 121 +++++++++++++-------- 2 files changed, 99 insertions(+), 69 deletions(-) (limited to 'text_recognizer/networks') diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py index cf64bcf..2260ee2 100644 --- a/text_recognizer/networks/efficientnet/efficientnet.py +++ b/text_recognizer/networks/efficientnet/efficientnet.py @@ -1,7 +1,6 @@ """Efficientnet backbone.""" from typing import Tuple -from attrs import define, field from torch import nn, Tensor from text_recognizer.networks.efficientnet.mbconv import MBConvBlock @@ -12,13 +11,9 @@ from text_recognizer.networks.efficientnet.utils import ( ) -@define(eq=False) class EfficientNet(nn.Module): """Efficientnet without classification head.""" - def __attrs_pre_init__(self) -> None: - super().__init__() - archs = { # width, depth, dropout "b0": (1.0, 1.0, 0.2), @@ -33,32 +28,32 @@ class EfficientNet(nn.Module): "l2": (4.3, 5.3, 0.5), } - arch: str = field() - params: Tuple[float, float, float] = field(default=None, init=False) - stochastic_dropout_rate: float = field(default=0.2) - bn_momentum: float = field(default=0.99) - bn_eps: float = field(default=1.0e-3) - depth: int = field(default=7) - out_channels: int = field(default=None, init=False) - _conv_stem: nn.Sequential = field(default=None, init=False) - _blocks: nn.ModuleList = field(default=None, init=False) - _conv_head: nn.Sequential = field(default=None, init=False) - - def __attrs_post_init__(self) -> None: - """Post init configuration.""" + def __init__( + self, + arch: str, + params: Tuple[float, float, float], + stochastic_dropout_rate: float = 0.2, + bn_momentum: float = 0.99, + bn_eps: float = 1.0e-3, + depth: int = 7, + ) -> None: + super().__init__() + self.params = self._get_arch_params(arch) + self.stochastic_dropout_rate = stochastic_dropout_rate + self.bn_momentum = bn_momentum + self.bn_eps = bn_eps + self.depth = depth + self.out_channels: int + self._conv_stem: nn.Sequential + self._blocks: nn.ModuleList + self._conv_head: nn.Sequential self._build() - @depth.validator - def _check_depth(self, attribute, value: str) -> None: - if not 5 <= value <= 7: - raise ValueError(f"Depth has to be between 5 and 7, was: {value}") - - @arch.validator - def _check_arch(self, attribute, value: str) -> None: + def _get_arch_params(self, value: str) -> Tuple[float, float, float]: """Validates the efficientnet architecure.""" if value not in self.archs: raise ValueError(f"{value} not a valid architecure.") - self.params = self.archs[value] + return self.archs[value] def _build(self) -> None: """Builds the efficientnet backbone.""" diff --git a/text_recognizer/networks/efficientnet/mbconv.py b/text_recognizer/networks/efficientnet/mbconv.py index 98e9353..64debd9 100644 --- a/text_recognizer/networks/efficientnet/mbconv.py +++ b/text_recognizer/networks/efficientnet/mbconv.py @@ -1,7 +1,6 @@ """Mobile inverted residual block.""" from typing import Optional, Tuple, Union -from attrs import define, field import torch from torch import nn, Tensor import torch.nn.functional as F @@ -14,18 +13,15 @@ def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]: return (stride,) * 2 if isinstance(stride, int) else stride -@define(eq=False) class BaseModule(nn.Module): """Base sub module class.""" - bn_momentum: float = field() - bn_eps: float = field() - block: nn.Sequential = field(init=False) - - def __attrs_pre_init__(self) -> None: + def __init__(self, bn_momentum: float, bn_eps: float, block: nn.Sequential) -> None: super().__init__() - def __attrs_post_init__(self) -> None: + self.bn_momentum = bn_momentum + self.bn_eps = bn_eps + self.block = block self._build() def _build(self) -> None: @@ -36,12 +32,20 @@ class BaseModule(nn.Module): return self.block(x) -@define(auto_attribs=True, eq=False) class InvertedBottleneck(BaseModule): """Inverted bottleneck module.""" - in_channels: int = field() - out_channels: int = field() + def __init__( + self, + bn_momentum: float, + bn_eps: float, + block: nn.Sequential, + in_channels: int, + out_channels: int, + ) -> None: + super().__init__(bn_momentum, bn_eps, block) + self.in_channels = in_channels + self.out_channels = out_channels def _build(self) -> None: self.block = nn.Sequential( @@ -60,13 +64,22 @@ class InvertedBottleneck(BaseModule): ) -@define(auto_attribs=True, eq=False) class Depthwise(BaseModule): """Depthwise convolution module.""" - channels: int = field() - kernel_size: int = field() - stride: int = field() + def __init__( + self, + bn_momentum: float, + bn_eps: float, + block: nn.Sequential, + channels: int, + kernel_size: int, + stride: int, + ) -> None: + super().__init__(bn_momentum, bn_eps, block) + self.channels = channels + self.kernel_size = kernel_size + self.stride = stride def _build(self) -> None: self.block = nn.Sequential( @@ -85,13 +98,23 @@ class Depthwise(BaseModule): ) -@define(auto_attribs=True, eq=False) class SqueezeAndExcite(BaseModule): """Sequeeze and excite module.""" - in_channels: int = field() - channels: int = field() - se_ratio: float = field() + def __init__( + self, + bn_momentum: float, + bn_eps: float, + block: nn.Sequential, + in_channels: int, + channels: int, + se_ratio: float, + ) -> None: + super().__init__(bn_momentum, bn_eps, block) + + self.in_channels = in_channels + self.channels = channels + self.se_ratio = se_ratio def _build(self) -> None: num_squeezed_channels = max(1, int(self.in_channels * self.se_ratio)) @@ -110,12 +133,20 @@ class SqueezeAndExcite(BaseModule): ) -@define(auto_attribs=True, eq=False) class Pointwise(BaseModule): """Pointwise module.""" - in_channels: int = field() - out_channels: int = field() + def __init__( + self, + bn_momentum: float, + bn_eps: float, + block: nn.Sequential, + in_channels: int, + out_channels: int, + ) -> None: + super().__init__(bn_momentum, bn_eps, block) + self.in_channels = in_channels + self.out_channels = out_channels def _build(self) -> None: self.block = nn.Sequential( @@ -133,28 +164,36 @@ class Pointwise(BaseModule): ) -@define(eq=False) class MBConvBlock(nn.Module): """Mobile Inverted Residual Bottleneck block.""" - def __attrs_pre_init__(self) -> None: + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Tuple[int, int], + stride: Tuple[int, int], + bn_momentum: float, + bn_eps: float, + se_ratio: float, + expand_ratio: int, + ) -> None: super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.bn_momentum = bn_momentum + self.bn_eps = bn_eps + self.se_ratio = se_ratio + self.expand_ratio = expand_ratio + self.pad = self._configure_padding() + self._inverted_bottleneck: Optional[InvertedBottleneck] + self._depthwise: nn.Sequential + self._squeeze_excite: nn.Sequential + self._pointwise: nn.Sequential + self._build() - in_channels: int = field() - out_channels: int = field() - kernel_size: Tuple[int, int] = field() - stride: Tuple[int, int] = field(converter=_convert_stride) - bn_momentum: float = field() - bn_eps: float = field() - se_ratio: float = field() - expand_ratio: int = field() - pad: Tuple[int, int, int, int] = field(init=False) - _inverted_bottleneck: Optional[InvertedBottleneck] = field(init=False) - _depthwise: nn.Sequential = field(init=False) - _squeeze_excite: nn.Sequential = field(init=False) - _pointwise: nn.Sequential = field(init=False) - - @pad.default def _configure_padding(self) -> Tuple[int, int, int, int]: """Set padding for convolutional layers.""" if self.stride == (2, 2): @@ -164,10 +203,6 @@ class MBConvBlock(nn.Module): ) * 2 return ((self.kernel_size - 1) // 2,) * 4 - def __attrs_post_init__(self) -> None: - """Post init configuration.""" - self._build() - def _build(self) -> None: has_se = self.se_ratio is not None and 0.0 < self.se_ratio < 1.0 inner_channels = self.in_channels * self.expand_ratio -- cgit v1.2.3-70-g09d2