diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-22 22:38:43 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-22 22:38:43 +0100 |
commit | 050e1bd284a173d2586ad4607e95d114691db563 (patch) | |
tree | f428b5a17396a7cd585e89d84765e3d7ed233618 /text_recognizer/networks | |
parent | 36875ceca1e00f5bb39a151c50ecf5e333b4cf79 (diff) |
Move efficientnet from encoder dir
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/efficientnet/__init__.py (renamed from text_recognizer/networks/encoders/efficientnet/__init__.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/efficientnet/efficientnet.py (renamed from text_recognizer/networks/encoders/efficientnet/efficientnet.py) | 16 | ||||
-rw-r--r-- | text_recognizer/networks/efficientnet/mbconv.py (renamed from text_recognizer/networks/encoders/efficientnet/mbconv.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/efficientnet/utils.py (renamed from text_recognizer/networks/encoders/efficientnet/utils.py) | 3 | ||||
-rw-r--r-- | text_recognizer/networks/encoders/__init__.py | 2 |
5 files changed, 13 insertions, 8 deletions
```diff
diff --git a/text_recognizer/networks/encoders/efficientnet/__init__.py b/text_recognizer/networks/efficientnet/__init__.py
index 344233f..344233f 100644
--- a/text_recognizer/networks/encoders/efficientnet/__init__.py
+++ b/text_recognizer/networks/efficientnet/__init__.py
diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py
index 9454514..4c9ed75 100644
--- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/efficientnet/efficientnet.py
@@ -4,8 +4,8 @@ from typing import Tuple
 import attr
 from torch import nn, Tensor
 
-from .mbconv import MBConvBlock
-from .utils import (
+from text_recognizer.networks.efficientnet.mbconv import MBConvBlock
+from text_recognizer.networks.efficientnet.utils import (
     block_args,
     round_filters,
     round_repeats,
@@ -38,6 +38,7 @@ class EfficientNet(nn.Module):
     stochastic_dropout_rate: float = attr.ib(default=0.2)
     bn_momentum: float = attr.ib(default=0.99)
     bn_eps: float = attr.ib(default=1.0e-3)
+    depth: int = attr.ib(default=7)
     out_channels: int = attr.ib(default=None, init=False)
     _conv_stem: nn.Sequential = attr.ib(default=None, init=False)
     _blocks: nn.ModuleList = attr.ib(default=None, init=False)
@@ -47,8 +48,13 @@ class EfficientNet(nn.Module):
         """Post init configuration."""
         self._build()
 
+    @depth.validator
+    def _check_depth(self, attribute: attr._make.Attribute, value: str) -> None:
+        if not 5 <= value <= 7:
+            raise ValueError(f"Depth has to be between 5 and 7, was: {value}")
+
     @arch.validator
-    def check_arch(self, attribute: attr._make.Attribute, value: str) -> None:
+    def _check_arch(self, attribute: attr._make.Attribute, value: str) -> None:
         """Validates the efficientnet architecure."""
         if value not in self.archs:
             raise ValueError(f"{value} not a valid architecure.")
@@ -56,7 +62,7 @@ class EfficientNet(nn.Module):
 
     def _build(self) -> None:
         """Builds the efficientnet backbone."""
-        _block_args = block_args()
+        _block_args = block_args()[: self.depth]
         in_channels = 1  # BW
         out_channels = round_filters(32, self.params)
         self._conv_stem = nn.Sequential(
@@ -88,7 +94,7 @@ class EfficientNet(nn.Module):
                 args.in_channels = args.out_channels
                 args.stride = 1
 
-        in_channels = round_filters(320, self.params)
+        in_channels = round_filters(_block_args[-1].out_channels, self.params)
         self.out_channels = round_filters(1280, self.params)
         self._conv_head = nn.Sequential(
             nn.Conv2d(
diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/efficientnet/mbconv.py
index 4b051eb..4b051eb 100644
--- a/text_recognizer/networks/encoders/efficientnet/mbconv.py
+++ b/text_recognizer/networks/efficientnet/mbconv.py
diff --git a/text_recognizer/networks/encoders/efficientnet/utils.py b/text_recognizer/networks/efficientnet/utils.py
index 2b1aebb..5234324 100644
--- a/text_recognizer/networks/encoders/efficientnet/utils.py
+++ b/text_recognizer/networks/efficientnet/utils.py
@@ -77,7 +77,8 @@ def block_args() -> List[DictConfig]:
         [2, 5, (2, 2), 6, 24, 40, 0.25],
         [3, 3, (2, 2), 6, 40, 80, 0.25],
         [3, 5, (1, 1), 6, 80, 112, 0.25],
-        [1, 3, (1, 1), 6, 112, 320, 0.25],
+        [4, 5, (2, 2), 6, 112, 192, 0.25],
+        [1, 3, (1, 1), 6, 192, 320, 0.25],
     ]
     block_args_ = []
     for row in args:
diff --git a/text_recognizer/networks/encoders/__init__.py b/text_recognizer/networks/encoders/__init__.py
deleted file mode 100644
index 25aed0e..0000000
--- a/text_recognizer/networks/encoders/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-"""Vision backbones."""
-from .efficientnet import EfficientNet
```