summaryrefslogtreecommitdiff
path: root/text_recognizer/networks
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--text_recognizer/networks/efficientnet/__init__.py (renamed from text_recognizer/networks/encoders/efficientnet/__init__.py)0
-rw-r--r--text_recognizer/networks/efficientnet/efficientnet.py (renamed from text_recognizer/networks/encoders/efficientnet/efficientnet.py)16
-rw-r--r--text_recognizer/networks/efficientnet/mbconv.py (renamed from text_recognizer/networks/encoders/efficientnet/mbconv.py)0
-rw-r--r--text_recognizer/networks/efficientnet/utils.py (renamed from text_recognizer/networks/encoders/efficientnet/utils.py)3
-rw-r--r--text_recognizer/networks/encoders/__init__.py2
5 files changed, 13 insertions, 8 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/__init__.py b/text_recognizer/networks/efficientnet/__init__.py
index 344233f..344233f 100644
--- a/text_recognizer/networks/encoders/efficientnet/__init__.py
+++ b/text_recognizer/networks/efficientnet/__init__.py
diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py
index 9454514..4c9ed75 100644
--- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/efficientnet/efficientnet.py
@@ -4,8 +4,8 @@ from typing import Tuple
import attr
from torch import nn, Tensor
-from .mbconv import MBConvBlock
-from .utils import (
+from text_recognizer.networks.efficientnet.mbconv import MBConvBlock
+from text_recognizer.networks.efficientnet.utils import (
block_args,
round_filters,
round_repeats,
@@ -38,6 +38,7 @@ class EfficientNet(nn.Module):
stochastic_dropout_rate: float = attr.ib(default=0.2)
bn_momentum: float = attr.ib(default=0.99)
bn_eps: float = attr.ib(default=1.0e-3)
+ depth: int = attr.ib(default=7)
out_channels: int = attr.ib(default=None, init=False)
_conv_stem: nn.Sequential = attr.ib(default=None, init=False)
_blocks: nn.ModuleList = attr.ib(default=None, init=False)
@@ -47,8 +48,13 @@ class EfficientNet(nn.Module):
"""Post init configuration."""
self._build()
+ @depth.validator
+ def _check_depth(self, attribute: attr._make.Attribute, value: str) -> None:
+ if not 5 <= value <= 7:
+ raise ValueError(f"Depth has to be between 5 and 7, was: {value}")
+
@arch.validator
- def check_arch(self, attribute: attr._make.Attribute, value: str) -> None:
+ def _check_arch(self, attribute: attr._make.Attribute, value: str) -> None:
"""Validates the efficientnet architecure."""
if value not in self.archs:
raise ValueError(f"{value} not a valid architecure.")
@@ -56,7 +62,7 @@ class EfficientNet(nn.Module):
def _build(self) -> None:
"""Builds the efficientnet backbone."""
- _block_args = block_args()
+ _block_args = block_args()[: self.depth]
in_channels = 1 # BW
out_channels = round_filters(32, self.params)
self._conv_stem = nn.Sequential(
@@ -88,7 +94,7 @@ class EfficientNet(nn.Module):
args.in_channels = args.out_channels
args.stride = 1
- in_channels = round_filters(320, self.params)
+ in_channels = round_filters(_block_args[-1].out_channels, self.params)
self.out_channels = round_filters(1280, self.params)
self._conv_head = nn.Sequential(
nn.Conv2d(
diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/efficientnet/mbconv.py
index 4b051eb..4b051eb 100644
--- a/text_recognizer/networks/encoders/efficientnet/mbconv.py
+++ b/text_recognizer/networks/efficientnet/mbconv.py
diff --git a/text_recognizer/networks/encoders/efficientnet/utils.py b/text_recognizer/networks/efficientnet/utils.py
index 2b1aebb..5234324 100644
--- a/text_recognizer/networks/encoders/efficientnet/utils.py
+++ b/text_recognizer/networks/efficientnet/utils.py
@@ -77,7 +77,8 @@ def block_args() -> List[DictConfig]:
[2, 5, (2, 2), 6, 24, 40, 0.25],
[3, 3, (2, 2), 6, 40, 80, 0.25],
[3, 5, (1, 1), 6, 80, 112, 0.25],
- [1, 3, (1, 1), 6, 112, 320, 0.25],
+ [4, 5, (2, 2), 6, 112, 192, 0.25],
+ [1, 3, (1, 1), 6, 192, 320, 0.25],
]
block_args_ = []
for row in args:
diff --git a/text_recognizer/networks/encoders/__init__.py b/text_recognizer/networks/encoders/__init__.py
deleted file mode 100644
index 25aed0e..0000000
--- a/text_recognizer/networks/encoders/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-"""Vision backbones."""
-from .efficientnet import EfficientNet