Diffstat (limited to 'text_recognizer/networks/encoders/efficientnet/efficientnet.py')
-rw-r--r-- | text_recognizer/networks/encoders/efficientnet/efficientnet.py | 58 |
1 file changed, 32 insertions, 26 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
index 6719efb..a36150a 100644
--- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
@@ -1,4 +1,7 @@
 """Efficient net."""
+from typing import Tuple
+
+import attr
 from torch import nn, Tensor
 
 from .mbconv import MBConvBlock
@@ -9,10 +12,13 @@ from .utils import (
 )
 
 
+@attr.s
 class EfficientNet(nn.Module):
-    # TODO: attr
+    def __attrs_pre_init__(self) -> None:
+        super().__init__()
+
     archs = {
-        # width,depth0res,dropout
+        # width, depth, dropout
         "b0": (1.0, 1.0, 0.2),
         "b1": (1.0, 1.1, 0.2),
         "b2": (1.1, 1.2, 0.3),
@@ -25,30 +31,30 @@ class EfficientNet(nn.Module):
         "l2": (4.3, 5.3, 0.5),
     }
 
-    def __init__(
-        self,
-        arch: str,
-        out_channels: int = 1280,
-        stochastic_dropout_rate: float = 0.2,
-        bn_momentum: float = 0.99,
-        bn_eps: float = 1.0e-3,
-    ) -> None:
-        super().__init__()
-        assert arch in self.archs, f"{arch} not a valid efficient net architecure!"
-        self.arch = self.archs[arch]
-        self.out_channels = out_channels
-        self.stochastic_dropout_rate = stochastic_dropout_rate
-        self.bn_momentum = bn_momentum
-        self.bn_eps = bn_eps
-        self._conv_stem: nn.Sequential = None
-        self._blocks: nn.ModuleList = None
-        self._conv_head: nn.Sequential = None
+    arch: str = attr.ib()
+    params: Tuple[float, float, float] = attr.ib(default=None, init=False)
+    out_channels: int = attr.ib(default=1280)
+    stochastic_dropout_rate: float = attr.ib(default=0.2)
+    bn_momentum: float = attr.ib(default=0.99)
+    bn_eps: float = attr.ib(default=1.0e-3)
+    _conv_stem: nn.Sequential = attr.ib(default=None, init=False)
+    _blocks: nn.ModuleList = attr.ib(default=None, init=False)
+    _conv_head: nn.Sequential = attr.ib(default=None, init=False)
+
+    def __attrs_post_init__(self) -> None:
+        """Post init configuration."""
         self._build()
 
+    @arch.validator
+    def check_arch(self, attribute: attr._make.Attribute, value: str) -> None:
+        if value not in self.archs:
+            raise ValueError(f"{value} not a valid architecure.")
+        self.params = self.archs[value]
+
     def _build(self) -> None:
         _block_args = block_args()
         in_channels = 1  # BW
-        out_channels = round_filters(32, self.arch)
+        out_channels = round_filters(32, self.params)
         self._conv_stem = nn.Sequential(
             nn.ZeroPad2d((0, 1, 0, 1)),
             nn.Conv2d(
@@ -65,9 +71,9 @@ class EfficientNet(nn.Module):
         )
         self._blocks = nn.ModuleList([])
         for args in _block_args:
-            args.in_channels = round_filters(args.in_channels, self.arch)
-            args.out_channels = round_filters(args.out_channels, self.arch)
-            args.num_repeats = round_repeats(args.num_repeats, self.arch)
+            args.in_channels = round_filters(args.in_channels, self.params)
+            args.out_channels = round_filters(args.out_channels, self.params)
+            args.num_repeats = round_repeats(args.num_repeats, self.params)
             for _ in range(args.num_repeats):
                 self._blocks.append(
                     MBConvBlock(
@@ -77,8 +83,8 @@ class EfficientNet(nn.Module):
                 args.in_channels = args.out_channels
                 args.stride = 1
 
-        in_channels = round_filters(320, self.arch)
-        out_channels = round_filters(self.out_channels, self.arch)
+        in_channels = round_filters(320, self.params)
+        out_channels = round_filters(self.out_channels, self.params)
         self._conv_head = nn.Sequential(
             nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
             nn.BatchNorm2d(
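The pattern this diff introduces, an attrs-decorated nn.Module, hinges on three hooks: __attrs_pre_init__ calls nn.Module.__init__() before attrs assigns any fields (otherwise nn.Module.__setattr__ fails because its internal registries do not exist yet), a validator guards the arch field, and __attrs_post_init__ builds the submodules once all fields are set and validated. The sketch below shows that wiring in isolation; TinyNet, its fields, and the toy archs table are invented for illustration and are not part of this repository, and it assumes attrs >= 20.3.0, the first release with __attrs_pre_init__.

"""Minimal sketch of the attrs + nn.Module pattern shown in the diff (illustrative names)."""
from typing import Tuple

import attr
import torch
from torch import nn, Tensor


# eq=False keeps Python's identity-based __hash__; nn.Module relies on module
# hashability when it walks module trees (e.g. in .parameters()).
@attr.s(eq=False)
class TinyNet(nn.Module):
    """Toy module configured via attrs instead of a hand-written __init__."""

    def __attrs_pre_init__(self) -> None:
        # nn.Module.__init__ must run before attrs assigns any fields,
        # otherwise nn.Module.__setattr__ fails on the missing registries.
        super().__init__()

    archs = {"small": (16, 0.1), "large": (64, 0.3)}  # width, dropout

    arch: str = attr.ib()
    params: Tuple[int, float] = attr.ib(default=None, init=False)
    hidden: nn.Sequential = attr.ib(default=None, init=False)

    @arch.validator
    def _check_arch(self, attribute: attr.Attribute, value: str) -> None:
        # Pure validation; derived state is resolved in __attrs_post_init__.
        if value not in self.archs:
            raise ValueError(f"{value} is not a valid architecture.")

    def __attrs_post_init__(self) -> None:
        # Runs after all fields are assigned and validated; build submodules here.
        self.params = self.archs[self.arch]
        width, dropout = self.params
        self.hidden = nn.Sequential(
            nn.Linear(8, width),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(width, 2),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.hidden(x)


if __name__ == "__main__":
    net = TinyNet(arch="small")
    print(net(torch.randn(4, 8)).shape)  # torch.Size([4, 2])

Compared with the diff, the sketch makes two deliberate departures: eq=False keeps identity hashing (plain @attr.s sets __hash__ to None, which breaks nn.Module methods that put modules in sets, such as .parameters()), and the validator only validates, with the derived params resolved in __attrs_post_init__ to keep side effects out of validation.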