summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--text_recognizer/networks/efficientnet/__init__.py (renamed from text_recognizer/networks/encoders/efficientnet/__init__.py)0
-rw-r--r--text_recognizer/networks/efficientnet/efficientnet.py (renamed from text_recognizer/networks/encoders/efficientnet/efficientnet.py)16
-rw-r--r--text_recognizer/networks/efficientnet/mbconv.py (renamed from text_recognizer/networks/encoders/efficientnet/mbconv.py)0
-rw-r--r--text_recognizer/networks/efficientnet/utils.py (renamed from text_recognizer/networks/encoders/efficientnet/utils.py)3
-rw-r--r--text_recognizer/networks/encoders/__init__.py2
-rw-r--r--training/conf/experiment/conv_transformer_lines.yaml11
-rw-r--r--training/conf/experiment/conv_transformer_paragraphs.yaml2
-rw-r--r--training/conf/experiment/conv_transformer_paragraphs_wp.yaml2
-rw-r--r--training/conf/network/encoder/efficientnet.yaml2
9 files changed, 22 insertions, 16 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/__init__.py b/text_recognizer/networks/efficientnet/__init__.py
index 344233f..344233f 100644
--- a/text_recognizer/networks/encoders/efficientnet/__init__.py
+++ b/text_recognizer/networks/efficientnet/__init__.py
diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py
index 9454514..4c9ed75 100644
--- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/efficientnet/efficientnet.py
@@ -4,8 +4,8 @@ from typing import Tuple
import attr
from torch import nn, Tensor
-from .mbconv import MBConvBlock
-from .utils import (
+from text_recognizer.networks.efficientnet.mbconv import MBConvBlock
+from text_recognizer.networks.efficientnet.utils import (
block_args,
round_filters,
round_repeats,
@@ -38,6 +38,7 @@ class EfficientNet(nn.Module):
stochastic_dropout_rate: float = attr.ib(default=0.2)
bn_momentum: float = attr.ib(default=0.99)
bn_eps: float = attr.ib(default=1.0e-3)
+ depth: int = attr.ib(default=7)
out_channels: int = attr.ib(default=None, init=False)
_conv_stem: nn.Sequential = attr.ib(default=None, init=False)
_blocks: nn.ModuleList = attr.ib(default=None, init=False)
@@ -47,8 +48,13 @@ class EfficientNet(nn.Module):
"""Post init configuration."""
self._build()
+ @depth.validator
+ def _check_depth(self, attribute: attr._make.Attribute, value: str) -> None:
+ if not 5 <= value <= 7:
+ raise ValueError(f"Depth has to be between 5 and 7, was: {value}")
+
@arch.validator
- def check_arch(self, attribute: attr._make.Attribute, value: str) -> None:
+ def _check_arch(self, attribute: attr._make.Attribute, value: str) -> None:
"""Validates the efficientnet architecure."""
if value not in self.archs:
raise ValueError(f"{value} not a valid architecure.")
@@ -56,7 +62,7 @@ class EfficientNet(nn.Module):
def _build(self) -> None:
"""Builds the efficientnet backbone."""
- _block_args = block_args()
+ _block_args = block_args()[: self.depth]
in_channels = 1 # BW
out_channels = round_filters(32, self.params)
self._conv_stem = nn.Sequential(
@@ -88,7 +94,7 @@ class EfficientNet(nn.Module):
args.in_channels = args.out_channels
args.stride = 1
- in_channels = round_filters(320, self.params)
+ in_channels = round_filters(_block_args[-1].out_channels, self.params)
self.out_channels = round_filters(1280, self.params)
self._conv_head = nn.Sequential(
nn.Conv2d(
diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/efficientnet/mbconv.py
index 4b051eb..4b051eb 100644
--- a/text_recognizer/networks/encoders/efficientnet/mbconv.py
+++ b/text_recognizer/networks/efficientnet/mbconv.py
diff --git a/text_recognizer/networks/encoders/efficientnet/utils.py b/text_recognizer/networks/efficientnet/utils.py
index 2b1aebb..5234324 100644
--- a/text_recognizer/networks/encoders/efficientnet/utils.py
+++ b/text_recognizer/networks/efficientnet/utils.py
@@ -77,7 +77,8 @@ def block_args() -> List[DictConfig]:
[2, 5, (2, 2), 6, 24, 40, 0.25],
[3, 3, (2, 2), 6, 40, 80, 0.25],
[3, 5, (1, 1), 6, 80, 112, 0.25],
- [1, 3, (1, 1), 6, 112, 320, 0.25],
+ [4, 5, (2, 2), 6, 112, 192, 0.25],
+ [1, 3, (1, 1), 6, 192, 320, 0.25],
]
block_args_ = []
for row in args:
diff --git a/text_recognizer/networks/encoders/__init__.py b/text_recognizer/networks/encoders/__init__.py
deleted file mode 100644
index 25aed0e..0000000
--- a/text_recognizer/networks/encoders/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-"""Vision backbones."""
-from .efficientnet import EfficientNet
diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml
index 6c266b8..2918317 100644
--- a/training/conf/experiment/conv_transformer_lines.yaml
+++ b/training/conf/experiment/conv_transformer_lines.yaml
@@ -64,10 +64,10 @@ rotary_embedding: &rotary_embedding
dim: 64
attn: &attn
- dim: &hidden_dim 256
+ dim: &hidden_dim 128
num_heads: 4
dim_head: 64
- dropout_rate: &dropout_rate 0.5
+ dropout_rate: &dropout_rate 0.2
network:
_target_: text_recognizer.networks.conv_transformer.ConvTransformer
@@ -76,8 +76,9 @@ network:
num_classes: *num_classes
pad_index: *ignore_index
encoder:
- _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet
- arch: b1
+ _target_: text_recognizer.networks.efficientnet.EfficientNet
+ arch: b0
+ depth: 5
stochastic_dropout_rate: 0.2
bn_momentum: 0.99
bn_eps: 1.0e-3
@@ -113,7 +114,7 @@ network:
dim: *hidden_dim
heads: 4
shape: *shape
- depth: 1
+ depth: 2
dim_head: 64
dim_index: 1
diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml
index 4f15ef2..32f5763 100644
--- a/training/conf/experiment/conv_transformer_paragraphs.yaml
+++ b/training/conf/experiment/conv_transformer_paragraphs.yaml
@@ -78,7 +78,7 @@ network:
num_classes: *num_classes
pad_index: *ignore_index
encoder:
- _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet
+ _target_: text_recognizer.networks.efficientnet.EfficientNet
arch: b1
stochastic_dropout_rate: 0.2
bn_momentum: 0.99
diff --git a/training/conf/experiment/conv_transformer_paragraphs_wp.yaml b/training/conf/experiment/conv_transformer_paragraphs_wp.yaml
index 91fba9a..bf192ec 100644
--- a/training/conf/experiment/conv_transformer_paragraphs_wp.yaml
+++ b/training/conf/experiment/conv_transformer_paragraphs_wp.yaml
@@ -89,7 +89,7 @@ network:
num_classes: *num_classes
pad_index: *ignore_index
encoder:
- _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet
+ _target_: text_recognizer.networks.efficientnet.EfficientNet
arch: b0
out_channels: 1280
stochastic_dropout_rate: 0.2
diff --git a/training/conf/network/encoder/efficientnet.yaml b/training/conf/network/encoder/efficientnet.yaml
index 0e62293..a7be069 100644
--- a/training/conf/network/encoder/efficientnet.yaml
+++ b/training/conf/network/encoder/efficientnet.yaml
@@ -1,4 +1,4 @@
-_target_: text_recognizer.networks.encoders.efficientnet.EfficientNet
+_target_: text_recognizer.networks.efficientnet.EfficientNet
arch: b0
stochastic_dropout_rate: 0.2
bn_momentum: 0.99