diff options
-rw-r--r-- | text_recognizer/networks/efficientnet/__init__.py (renamed from text_recognizer/networks/encoders/efficientnet/__init__.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/efficientnet/efficientnet.py (renamed from text_recognizer/networks/encoders/efficientnet/efficientnet.py) | 16 | ||||
-rw-r--r-- | text_recognizer/networks/efficientnet/mbconv.py (renamed from text_recognizer/networks/encoders/efficientnet/mbconv.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/efficientnet/utils.py (renamed from text_recognizer/networks/encoders/efficientnet/utils.py) | 3 | ||||
-rw-r--r-- | text_recognizer/networks/encoders/__init__.py | 2 | ||||
-rw-r--r-- | training/conf/experiment/conv_transformer_lines.yaml | 11 | ||||
-rw-r--r-- | training/conf/experiment/conv_transformer_paragraphs.yaml | 2 | ||||
-rw-r--r-- | training/conf/experiment/conv_transformer_paragraphs_wp.yaml | 2 | ||||
-rw-r--r-- | training/conf/network/encoder/efficientnet.yaml | 2 |
9 files changed, 22 insertions, 16 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/__init__.py b/text_recognizer/networks/efficientnet/__init__.py index 344233f..344233f 100644 --- a/text_recognizer/networks/encoders/efficientnet/__init__.py +++ b/text_recognizer/networks/efficientnet/__init__.py diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py index 9454514..4c9ed75 100644 --- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py +++ b/text_recognizer/networks/efficientnet/efficientnet.py @@ -4,8 +4,8 @@ from typing import Tuple import attr from torch import nn, Tensor -from .mbconv import MBConvBlock -from .utils import ( +from text_recognizer.networks.efficientnet.mbconv import MBConvBlock +from text_recognizer.networks.efficientnet.utils import ( block_args, round_filters, round_repeats, @@ -38,6 +38,7 @@ class EfficientNet(nn.Module): stochastic_dropout_rate: float = attr.ib(default=0.2) bn_momentum: float = attr.ib(default=0.99) bn_eps: float = attr.ib(default=1.0e-3) + depth: int = attr.ib(default=7) out_channels: int = attr.ib(default=None, init=False) _conv_stem: nn.Sequential = attr.ib(default=None, init=False) _blocks: nn.ModuleList = attr.ib(default=None, init=False) @@ -47,8 +48,13 @@ class EfficientNet(nn.Module): """Post init configuration.""" self._build() + @depth.validator + def _check_depth(self, attribute: attr._make.Attribute, value: str) -> None: + if not 5 <= value <= 7: + raise ValueError(f"Depth has to be between 5 and 7, was: {value}") + @arch.validator - def check_arch(self, attribute: attr._make.Attribute, value: str) -> None: + def _check_arch(self, attribute: attr._make.Attribute, value: str) -> None: """Validates the efficientnet architecure.""" if value not in self.archs: raise ValueError(f"{value} not a valid architecure.") @@ -56,7 +62,7 @@ class EfficientNet(nn.Module): def _build(self) -> None: """Builds the efficientnet backbone.""" - _block_args = block_args() + _block_args = block_args()[: self.depth] in_channels = 1 # BW out_channels = round_filters(32, self.params) self._conv_stem = nn.Sequential( @@ -88,7 +94,7 @@ class EfficientNet(nn.Module): args.in_channels = args.out_channels args.stride = 1 - in_channels = round_filters(320, self.params) + in_channels = round_filters(_block_args[-1].out_channels, self.params) self.out_channels = round_filters(1280, self.params) self._conv_head = nn.Sequential( nn.Conv2d( diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/efficientnet/mbconv.py index 4b051eb..4b051eb 100644 --- a/text_recognizer/networks/encoders/efficientnet/mbconv.py +++ b/text_recognizer/networks/efficientnet/mbconv.py diff --git a/text_recognizer/networks/encoders/efficientnet/utils.py b/text_recognizer/networks/efficientnet/utils.py index 2b1aebb..5234324 100644 --- a/text_recognizer/networks/encoders/efficientnet/utils.py +++ b/text_recognizer/networks/efficientnet/utils.py @@ -77,7 +77,8 @@ def block_args() -> List[DictConfig]: [2, 5, (2, 2), 6, 24, 40, 0.25], [3, 3, (2, 2), 6, 40, 80, 0.25], [3, 5, (1, 1), 6, 80, 112, 0.25], - [1, 3, (1, 1), 6, 112, 320, 0.25], + [4, 5, (2, 2), 6, 112, 192, 0.25], + [1, 3, (1, 1), 6, 192, 320, 0.25], ] block_args_ = [] for row in args: diff --git a/text_recognizer/networks/encoders/__init__.py b/text_recognizer/networks/encoders/__init__.py deleted file mode 100644 index 25aed0e..0000000 --- a/text_recognizer/networks/encoders/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""Vision backbones.""" -from .efficientnet import EfficientNet diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml index 6c266b8..2918317 100644 --- a/training/conf/experiment/conv_transformer_lines.yaml +++ b/training/conf/experiment/conv_transformer_lines.yaml @@ -64,10 +64,10 @@ rotary_embedding: &rotary_embedding dim: 64 attn: &attn - dim: &hidden_dim 256 + dim: &hidden_dim 128 num_heads: 4 dim_head: 64 - dropout_rate: &dropout_rate 0.5 + dropout_rate: &dropout_rate 0.2 network: _target_: text_recognizer.networks.conv_transformer.ConvTransformer @@ -76,8 +76,9 @@ network: num_classes: *num_classes pad_index: *ignore_index encoder: - _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet - arch: b1 + _target_: text_recognizer.networks.efficientnet.EfficientNet + arch: b0 + depth: 5 stochastic_dropout_rate: 0.2 bn_momentum: 0.99 bn_eps: 1.0e-3 @@ -113,7 +114,7 @@ network: dim: *hidden_dim heads: 4 shape: *shape - depth: 1 + depth: 2 dim_head: 64 dim_index: 1 diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml index 4f15ef2..32f5763 100644 --- a/training/conf/experiment/conv_transformer_paragraphs.yaml +++ b/training/conf/experiment/conv_transformer_paragraphs.yaml @@ -78,7 +78,7 @@ network: num_classes: *num_classes pad_index: *ignore_index encoder: - _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet + _target_: text_recognizer.networks.efficientnet.EfficientNet arch: b1 stochastic_dropout_rate: 0.2 bn_momentum: 0.99 diff --git a/training/conf/experiment/conv_transformer_paragraphs_wp.yaml b/training/conf/experiment/conv_transformer_paragraphs_wp.yaml index 91fba9a..bf192ec 100644 --- a/training/conf/experiment/conv_transformer_paragraphs_wp.yaml +++ b/training/conf/experiment/conv_transformer_paragraphs_wp.yaml @@ -89,7 +89,7 @@ network: num_classes: *num_classes pad_index: *ignore_index encoder: - _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet + _target_: text_recognizer.networks.efficientnet.EfficientNet arch: b0 out_channels: 1280 stochastic_dropout_rate: 0.2 diff --git a/training/conf/network/encoder/efficientnet.yaml b/training/conf/network/encoder/efficientnet.yaml index 0e62293..a7be069 100644 --- a/training/conf/network/encoder/efficientnet.yaml +++ b/training/conf/network/encoder/efficientnet.yaml @@ -1,4 +1,4 @@ -_target_: text_recognizer.networks.encoders.efficientnet.EfficientNet +_target_: text_recognizer.networks.efficientnet.EfficientNet arch: b0 stochastic_dropout_rate: 0.2 bn_momentum: 0.99 |