summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/encoders/efficientnet/efficientnet.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/encoders/efficientnet/efficientnet.py')
-rw-r--r--text_recognizer/networks/encoders/efficientnet/efficientnet.py58
1 files changed, 32 insertions, 26 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
index 6719efb..a36150a 100644
--- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
@@ -1,4 +1,7 @@
"""Efficient net."""
+from typing import Tuple
+
+import attr
from torch import nn, Tensor
from .mbconv import MBConvBlock
@@ -9,10 +12,13 @@ from .utils import (
)
+@attr.s
class EfficientNet(nn.Module):
- # TODO: attr
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
+
archs = {
- # width,depth0res,dropout
+ # width, depth, dropout
"b0": (1.0, 1.0, 0.2),
"b1": (1.0, 1.1, 0.2),
"b2": (1.1, 1.2, 0.3),
@@ -25,30 +31,30 @@ class EfficientNet(nn.Module):
"l2": (4.3, 5.3, 0.5),
}
- def __init__(
- self,
- arch: str,
- out_channels: int = 1280,
- stochastic_dropout_rate: float = 0.2,
- bn_momentum: float = 0.99,
- bn_eps: float = 1.0e-3,
- ) -> None:
- super().__init__()
- assert arch in self.archs, f"{arch} not a valid efficient net architecure!"
- self.arch = self.archs[arch]
- self.out_channels = out_channels
- self.stochastic_dropout_rate = stochastic_dropout_rate
- self.bn_momentum = bn_momentum
- self.bn_eps = bn_eps
- self._conv_stem: nn.Sequential = None
- self._blocks: nn.ModuleList = None
- self._conv_head: nn.Sequential = None
+ arch: str = attr.ib()
+ params: Tuple[float, float, float] = attr.ib(default=None, init=False)
+ out_channels: int = attr.ib(default=1280)
+ stochastic_dropout_rate: float = attr.ib(default=0.2)
+ bn_momentum: float = attr.ib(default=0.99)
+ bn_eps: float = attr.ib(default=1.0e-3)
+ _conv_stem: nn.Sequential = attr.ib(default=None, init=False)
+ _blocks: nn.ModuleList = attr.ib(default=None, init=False)
+ _conv_head: nn.Sequential = attr.ib(default=None, init=False)
+
+ def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
self._build()
+ @arch.validator
+ def check_arch(self, attribute: attr._make.Attribute, value: str) -> None:
+ if value not in self.archs:
+ raise ValueError(f"{value} not a valid architecure.")
+ self.params = self.archs[value]
+
def _build(self) -> None:
_block_args = block_args()
in_channels = 1 # BW
- out_channels = round_filters(32, self.arch)
+ out_channels = round_filters(32, self.params)
self._conv_stem = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(
@@ -65,9 +71,9 @@ class EfficientNet(nn.Module):
)
self._blocks = nn.ModuleList([])
for args in _block_args:
- args.in_channels = round_filters(args.in_channels, self.arch)
- args.out_channels = round_filters(args.out_channels, self.arch)
- args.num_repeats = round_repeats(args.num_repeats, self.arch)
+ args.in_channels = round_filters(args.in_channels, self.params)
+ args.out_channels = round_filters(args.out_channels, self.params)
+ args.num_repeats = round_repeats(args.num_repeats, self.params)
for _ in range(args.num_repeats):
self._blocks.append(
MBConvBlock(
@@ -77,8 +83,8 @@ class EfficientNet(nn.Module):
args.in_channels = args.out_channels
args.stride = 1
- in_channels = round_filters(320, self.arch)
- out_channels = round_filters(self.out_channels, self.arch)
+ in_channels = round_filters(320, self.params)
+ out_channels = round_filters(self.out_channels, self.params)
self._conv_head = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
nn.BatchNorm2d(