summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/efficientnet/efficientnet.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/efficientnet/efficientnet.py')
-rw-r--r--text_recognizer/networks/efficientnet/efficientnet.py47
1 files changed, 21 insertions, 26 deletions
diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py
index cf64bcf..2260ee2 100644
--- a/text_recognizer/networks/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/efficientnet/efficientnet.py
@@ -1,7 +1,6 @@
"""Efficientnet backbone."""
from typing import Tuple
-from attrs import define, field
from torch import nn, Tensor
from text_recognizer.networks.efficientnet.mbconv import MBConvBlock
@@ -12,13 +11,9 @@ from text_recognizer.networks.efficientnet.utils import (
)
-@define(eq=False)
class EfficientNet(nn.Module):
"""Efficientnet without classification head."""
- def __attrs_pre_init__(self) -> None:
- super().__init__()
-
archs = {
# width, depth, dropout
"b0": (1.0, 1.0, 0.2),
@@ -33,32 +28,32 @@ class EfficientNet(nn.Module):
"l2": (4.3, 5.3, 0.5),
}
- arch: str = field()
- params: Tuple[float, float, float] = field(default=None, init=False)
- stochastic_dropout_rate: float = field(default=0.2)
- bn_momentum: float = field(default=0.99)
- bn_eps: float = field(default=1.0e-3)
- depth: int = field(default=7)
- out_channels: int = field(default=None, init=False)
- _conv_stem: nn.Sequential = field(default=None, init=False)
- _blocks: nn.ModuleList = field(default=None, init=False)
- _conv_head: nn.Sequential = field(default=None, init=False)
-
- def __attrs_post_init__(self) -> None:
- """Post init configuration."""
+ def __init__(
+ self,
+ arch: str,
+ params: Tuple[float, float, float],
+ stochastic_dropout_rate: float = 0.2,
+ bn_momentum: float = 0.99,
+ bn_eps: float = 1.0e-3,
+ depth: int = 7,
+ ) -> None:
+ super().__init__()
+ self.params = self._get_arch_params(arch)
+ self.stochastic_dropout_rate = stochastic_dropout_rate
+ self.bn_momentum = bn_momentum
+ self.bn_eps = bn_eps
+ self.depth = depth
+ self.out_channels: int
+ self._conv_stem: nn.Sequential
+ self._blocks: nn.ModuleList
+ self._conv_head: nn.Sequential
self._build()
- @depth.validator
- def _check_depth(self, attribute, value: str) -> None:
- if not 5 <= value <= 7:
- raise ValueError(f"Depth has to be between 5 and 7, was: {value}")
-
- @arch.validator
- def _check_arch(self, attribute, value: str) -> None:
+ def _get_arch_params(self, value: str) -> Tuple[float, float, float]:
"""Validates the efficientnet architecure."""
if value not in self.archs:
raise ValueError(f"{value} not a valid architecure.")
- self.params = self.archs[value]
+ return self.archs[value]
def _build(self) -> None:
"""Builds the efficientnet backbone."""