summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/efficientnet
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/efficientnet')
-rw-r--r--text_recognizer/networks/efficientnet/efficientnet.py32
-rw-r--r--text_recognizer/networks/efficientnet/mbconv.py71
2 files changed, 54 insertions, 49 deletions
diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py
index 4c9ed75..cf64bcf 100644
--- a/text_recognizer/networks/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/efficientnet/efficientnet.py
@@ -1,7 +1,7 @@
"""Efficientnet backbone."""
from typing import Tuple
-import attr
+from attrs import define, field
from torch import nn, Tensor
from text_recognizer.networks.efficientnet.mbconv import MBConvBlock
@@ -12,7 +12,7 @@ from text_recognizer.networks.efficientnet.utils import (
)
-@attr.s(eq=False)
+@define(eq=False)
class EfficientNet(nn.Module):
"""Efficientnet without classification head."""
@@ -33,28 +33,28 @@ class EfficientNet(nn.Module):
"l2": (4.3, 5.3, 0.5),
}
- arch: str = attr.ib()
- params: Tuple[float, float, float] = attr.ib(default=None, init=False)
- stochastic_dropout_rate: float = attr.ib(default=0.2)
- bn_momentum: float = attr.ib(default=0.99)
- bn_eps: float = attr.ib(default=1.0e-3)
- depth: int = attr.ib(default=7)
- out_channels: int = attr.ib(default=None, init=False)
- _conv_stem: nn.Sequential = attr.ib(default=None, init=False)
- _blocks: nn.ModuleList = attr.ib(default=None, init=False)
- _conv_head: nn.Sequential = attr.ib(default=None, init=False)
+ arch: str = field()
+ params: Tuple[float, float, float] = field(default=None, init=False)
+ stochastic_dropout_rate: float = field(default=0.2)
+ bn_momentum: float = field(default=0.99)
+ bn_eps: float = field(default=1.0e-3)
+ depth: int = field(default=7)
+ out_channels: int = field(default=None, init=False)
+ _conv_stem: nn.Sequential = field(default=None, init=False)
+ _blocks: nn.ModuleList = field(default=None, init=False)
+ _conv_head: nn.Sequential = field(default=None, init=False)
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
self._build()
@depth.validator
- def _check_depth(self, attribute: attr._make.Attribute, value: str) -> None:
+ def _check_depth(self, attribute, value: str) -> None:
if not 5 <= value <= 7:
raise ValueError(f"Depth has to be between 5 and 7, was: {value}")
@arch.validator
- def _check_arch(self, attribute: attr._make.Attribute, value: str) -> None:
+ def _check_arch(self, attribute, value: str) -> None:
"""Validates the efficientnet architecure."""
if value not in self.archs:
raise ValueError(f"{value} not a valid architecure.")
@@ -88,7 +88,9 @@ class EfficientNet(nn.Module):
for _ in range(num_repeats):
self._blocks.append(
MBConvBlock(
- **args, bn_momentum=self.bn_momentum, bn_eps=self.bn_eps,
+ **args,
+ bn_momentum=self.bn_momentum,
+ bn_eps=self.bn_eps,
)
)
args.in_channels = args.out_channels
diff --git a/text_recognizer/networks/efficientnet/mbconv.py b/text_recognizer/networks/efficientnet/mbconv.py
index beb7d57..98e9353 100644
--- a/text_recognizer/networks/efficientnet/mbconv.py
+++ b/text_recognizer/networks/efficientnet/mbconv.py
@@ -1,7 +1,7 @@
"""Mobile inverted residual block."""
from typing import Optional, Tuple, Union
-import attr
+from attrs import define, field
import torch
from torch import nn, Tensor
import torch.nn.functional as F
@@ -14,13 +14,13 @@ def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]:
return (stride,) * 2 if isinstance(stride, int) else stride
-@attr.s(eq=False)
+@define(eq=False)
class BaseModule(nn.Module):
"""Base sub module class."""
- bn_momentum: float = attr.ib()
- bn_eps: float = attr.ib()
- block: nn.Sequential = attr.ib(init=False)
+ bn_momentum: float = field()
+ bn_eps: float = field()
+ block: nn.Sequential = field(init=False)
def __attrs_pre_init__(self) -> None:
super().__init__()
@@ -36,12 +36,12 @@ class BaseModule(nn.Module):
return self.block(x)
-@attr.s(auto_attribs=True, eq=False)
+@define(auto_attribs=True, eq=False)
class InvertedBottleneck(BaseModule):
"""Inverted bottleneck module."""
- in_channels: int = attr.ib()
- out_channels: int = attr.ib()
+ in_channels: int = field()
+ out_channels: int = field()
def _build(self) -> None:
self.block = nn.Sequential(
@@ -60,13 +60,13 @@ class InvertedBottleneck(BaseModule):
)
-@attr.s(auto_attribs=True, eq=False)
+@define(auto_attribs=True, eq=False)
class Depthwise(BaseModule):
"""Depthwise convolution module."""
- channels: int = attr.ib()
- kernel_size: int = attr.ib()
- stride: int = attr.ib()
+ channels: int = field()
+ kernel_size: int = field()
+ stride: int = field()
def _build(self) -> None:
self.block = nn.Sequential(
@@ -85,13 +85,13 @@ class Depthwise(BaseModule):
)
-@attr.s(auto_attribs=True, eq=False)
+@define(auto_attribs=True, eq=False)
class SqueezeAndExcite(BaseModule):
"""Sequeeze and excite module."""
- in_channels: int = attr.ib()
- channels: int = attr.ib()
- se_ratio: float = attr.ib()
+ in_channels: int = field()
+ channels: int = field()
+ se_ratio: float = field()
def _build(self) -> None:
num_squeezed_channels = max(1, int(self.in_channels * self.se_ratio))
@@ -110,12 +110,12 @@ class SqueezeAndExcite(BaseModule):
)
-@attr.s(auto_attribs=True, eq=False)
+@define(auto_attribs=True, eq=False)
class Pointwise(BaseModule):
"""Pointwise module."""
- in_channels: int = attr.ib()
- out_channels: int = attr.ib()
+ in_channels: int = field()
+ out_channels: int = field()
def _build(self) -> None:
self.block = nn.Sequential(
@@ -133,32 +133,35 @@ class Pointwise(BaseModule):
)
-@attr.s(eq=False)
+@define(eq=False)
class MBConvBlock(nn.Module):
"""Mobile Inverted Residual Bottleneck block."""
def __attrs_pre_init__(self) -> None:
super().__init__()
- in_channels: int = attr.ib()
- out_channels: int = attr.ib()
- kernel_size: Tuple[int, int] = attr.ib()
- stride: Tuple[int, int] = attr.ib(converter=_convert_stride)
- bn_momentum: float = attr.ib()
- bn_eps: float = attr.ib()
- se_ratio: float = attr.ib()
- expand_ratio: int = attr.ib()
- pad: Tuple[int, int, int, int] = attr.ib(init=False)
- _inverted_bottleneck: Optional[InvertedBottleneck] = attr.ib(init=False)
- _depthwise: nn.Sequential = attr.ib(init=False)
- _squeeze_excite: nn.Sequential = attr.ib(init=False)
- _pointwise: nn.Sequential = attr.ib(init=False)
+ in_channels: int = field()
+ out_channels: int = field()
+ kernel_size: Tuple[int, int] = field()
+ stride: Tuple[int, int] = field(converter=_convert_stride)
+ bn_momentum: float = field()
+ bn_eps: float = field()
+ se_ratio: float = field()
+ expand_ratio: int = field()
+ pad: Tuple[int, int, int, int] = field(init=False)
+ _inverted_bottleneck: Optional[InvertedBottleneck] = field(init=False)
+ _depthwise: nn.Sequential = field(init=False)
+ _squeeze_excite: nn.Sequential = field(init=False)
+ _pointwise: nn.Sequential = field(init=False)
@pad.default
def _configure_padding(self) -> Tuple[int, int, int, int]:
"""Set padding for convolutional layers."""
if self.stride == (2, 2):
- return ((self.kernel_size - 1) // 2 - 1, (self.kernel_size - 1) // 2,) * 2
+ return (
+ (self.kernel_size - 1) // 2 - 1,
+ (self.kernel_size - 1) // 2,
+ ) * 2
return ((self.kernel_size - 1) // 2,) * 4
def __attrs_post_init__(self) -> None: