diff options
Diffstat (limited to 'text_recognizer/networks/efficientnet/efficientnet.py')
-rw-r--r-- | text_recognizer/networks/efficientnet/efficientnet.py | 32 |
1 files changed, 17 insertions, 15 deletions
diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py index 4c9ed75..cf64bcf 100644 --- a/text_recognizer/networks/efficientnet/efficientnet.py +++ b/text_recognizer/networks/efficientnet/efficientnet.py @@ -1,7 +1,7 @@ """Efficientnet backbone.""" from typing import Tuple -import attr +from attrs import define, field from torch import nn, Tensor from text_recognizer.networks.efficientnet.mbconv import MBConvBlock @@ -12,7 +12,7 @@ from text_recognizer.networks.efficientnet.utils import ( ) -@attr.s(eq=False) +@define(eq=False) class EfficientNet(nn.Module): """Efficientnet without classification head.""" @@ -33,28 +33,28 @@ class EfficientNet(nn.Module): "l2": (4.3, 5.3, 0.5), } - arch: str = attr.ib() - params: Tuple[float, float, float] = attr.ib(default=None, init=False) - stochastic_dropout_rate: float = attr.ib(default=0.2) - bn_momentum: float = attr.ib(default=0.99) - bn_eps: float = attr.ib(default=1.0e-3) - depth: int = attr.ib(default=7) - out_channels: int = attr.ib(default=None, init=False) - _conv_stem: nn.Sequential = attr.ib(default=None, init=False) - _blocks: nn.ModuleList = attr.ib(default=None, init=False) - _conv_head: nn.Sequential = attr.ib(default=None, init=False) + arch: str = field() + params: Tuple[float, float, float] = field(default=None, init=False) + stochastic_dropout_rate: float = field(default=0.2) + bn_momentum: float = field(default=0.99) + bn_eps: float = field(default=1.0e-3) + depth: int = field(default=7) + out_channels: int = field(default=None, init=False) + _conv_stem: nn.Sequential = field(default=None, init=False) + _blocks: nn.ModuleList = field(default=None, init=False) + _conv_head: nn.Sequential = field(default=None, init=False) def __attrs_post_init__(self) -> None: """Post init configuration.""" self._build() @depth.validator - def _check_depth(self, attribute: attr._make.Attribute, value: str) -> None: + def _check_depth(self, attribute, value: str) -> None: if not 5 <= value <= 7: raise ValueError(f"Depth has to be between 5 and 7, was: {value}") @arch.validator - def _check_arch(self, attribute: attr._make.Attribute, value: str) -> None: + def _check_arch(self, attribute, value: str) -> None: """Validates the efficientnet architecure.""" if value not in self.archs: raise ValueError(f"{value} not a valid architecure.") @@ -88,7 +88,9 @@ class EfficientNet(nn.Module): for _ in range(num_repeats): self._blocks.append( MBConvBlock( - **args, bn_momentum=self.bn_momentum, bn_eps=self.bn_eps, + **args, + bn_momentum=self.bn_momentum, + bn_eps=self.bn_eps, ) ) args.in_channels = args.out_channels |