diff options
Diffstat (limited to 'text_recognizer/networks/efficientnet/efficientnet.py')
-rw-r--r-- | text_recognizer/networks/efficientnet/efficientnet.py | 47 |
1 files changed, 21 insertions, 26 deletions
diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py index cf64bcf..2260ee2 100644 --- a/text_recognizer/networks/efficientnet/efficientnet.py +++ b/text_recognizer/networks/efficientnet/efficientnet.py @@ -1,7 +1,6 @@ """Efficientnet backbone.""" from typing import Tuple -from attrs import define, field from torch import nn, Tensor from text_recognizer.networks.efficientnet.mbconv import MBConvBlock @@ -12,13 +11,9 @@ from text_recognizer.networks.efficientnet.utils import ( ) -@define(eq=False) class EfficientNet(nn.Module): """Efficientnet without classification head.""" - def __attrs_pre_init__(self) -> None: - super().__init__() - archs = { # width, depth, dropout "b0": (1.0, 1.0, 0.2), @@ -33,32 +28,32 @@ class EfficientNet(nn.Module): "l2": (4.3, 5.3, 0.5), } - arch: str = field() - params: Tuple[float, float, float] = field(default=None, init=False) - stochastic_dropout_rate: float = field(default=0.2) - bn_momentum: float = field(default=0.99) - bn_eps: float = field(default=1.0e-3) - depth: int = field(default=7) - out_channels: int = field(default=None, init=False) - _conv_stem: nn.Sequential = field(default=None, init=False) - _blocks: nn.ModuleList = field(default=None, init=False) - _conv_head: nn.Sequential = field(default=None, init=False) - - def __attrs_post_init__(self) -> None: - """Post init configuration.""" + def __init__( + self, + arch: str, + params: Tuple[float, float, float], + stochastic_dropout_rate: float = 0.2, + bn_momentum: float = 0.99, + bn_eps: float = 1.0e-3, + depth: int = 7, + ) -> None: + super().__init__() + self.params = self._get_arch_params(arch) + self.stochastic_dropout_rate = stochastic_dropout_rate + self.bn_momentum = bn_momentum + self.bn_eps = bn_eps + self.depth = depth + self.out_channels: int + self._conv_stem: nn.Sequential + self._blocks: nn.ModuleList + self._conv_head: nn.Sequential self._build() - @depth.validator - def _check_depth(self, attribute, value: str) -> None: - if not 5 <= value <= 7: - raise ValueError(f"Depth has to be between 5 and 7, was: {value}") - - @arch.validator - def _check_arch(self, attribute, value: str) -> None: + def _get_arch_params(self, value: str) -> Tuple[float, float, float]: """Validates the efficientnet architecure.""" if value not in self.archs: raise ValueError(f"{value} not a valid architecure.") - self.params = self.archs[value] + return self.archs[value] def _build(self) -> None: """Builds the efficientnet backbone.""" |