summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/efficientnet/efficientnet.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/efficientnet/efficientnet.py')
-rw-r--r--text_recognizer/networks/efficientnet/efficientnet.py32
1 files changed, 17 insertions, 15 deletions
diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py
index 4c9ed75..cf64bcf 100644
--- a/text_recognizer/networks/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/efficientnet/efficientnet.py
@@ -1,7 +1,7 @@
"""Efficientnet backbone."""
from typing import Tuple
-import attr
+from attrs import define, field
from torch import nn, Tensor
from text_recognizer.networks.efficientnet.mbconv import MBConvBlock
@@ -12,7 +12,7 @@ from text_recognizer.networks.efficientnet.utils import (
)
-@attr.s(eq=False)
+@define(eq=False)
class EfficientNet(nn.Module):
"""Efficientnet without classification head."""
@@ -33,28 +33,28 @@ class EfficientNet(nn.Module):
"l2": (4.3, 5.3, 0.5),
}
- arch: str = attr.ib()
- params: Tuple[float, float, float] = attr.ib(default=None, init=False)
- stochastic_dropout_rate: float = attr.ib(default=0.2)
- bn_momentum: float = attr.ib(default=0.99)
- bn_eps: float = attr.ib(default=1.0e-3)
- depth: int = attr.ib(default=7)
- out_channels: int = attr.ib(default=None, init=False)
- _conv_stem: nn.Sequential = attr.ib(default=None, init=False)
- _blocks: nn.ModuleList = attr.ib(default=None, init=False)
- _conv_head: nn.Sequential = attr.ib(default=None, init=False)
+ arch: str = field()
+ params: Tuple[float, float, float] = field(default=None, init=False)
+ stochastic_dropout_rate: float = field(default=0.2)
+ bn_momentum: float = field(default=0.99)
+ bn_eps: float = field(default=1.0e-3)
+ depth: int = field(default=7)
+ out_channels: int = field(default=None, init=False)
+ _conv_stem: nn.Sequential = field(default=None, init=False)
+ _blocks: nn.ModuleList = field(default=None, init=False)
+ _conv_head: nn.Sequential = field(default=None, init=False)
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
self._build()
@depth.validator
- def _check_depth(self, attribute: attr._make.Attribute, value: str) -> None:
+ def _check_depth(self, attribute, value: str) -> None:
if not 5 <= value <= 7:
raise ValueError(f"Depth has to be between 5 and 7, was: {value}")
@arch.validator
- def _check_arch(self, attribute: attr._make.Attribute, value: str) -> None:
+ def _check_arch(self, attribute, value: str) -> None:
"""Validates the efficientnet architecure."""
if value not in self.archs:
raise ValueError(f"{value} not a valid architecure.")
@@ -88,7 +88,9 @@ class EfficientNet(nn.Module):
for _ in range(num_repeats):
self._blocks.append(
MBConvBlock(
- **args, bn_momentum=self.bn_momentum, bn_eps=self.bn_eps,
+ **args,
+ bn_momentum=self.bn_momentum,
+ bn_eps=self.bn_eps,
)
)
args.in_channels = args.out_channels