summaryrefslogtreecommitdiff
path: root/text_recognizer/networks
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-06-05 23:39:11 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-06-05 23:39:11 +0200
commit65df6a72b002c4b23d6f2eb545839e157f7f2aa0 (patch)
treed78df1d7143dc9ff9e29afd4fd6bc7490bc79418 /text_recognizer/networks
parent8bc4b4cab00a2777a748c10fca9b3ee01e32277c (diff)
Remove attrs
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--text_recognizer/networks/efficientnet/efficientnet.py47
-rw-r--r--text_recognizer/networks/efficientnet/mbconv.py121
2 files changed, 99 insertions, 69 deletions
diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py
index cf64bcf..2260ee2 100644
--- a/text_recognizer/networks/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/efficientnet/efficientnet.py
@@ -1,7 +1,6 @@
"""Efficientnet backbone."""
from typing import Tuple
-from attrs import define, field
from torch import nn, Tensor
from text_recognizer.networks.efficientnet.mbconv import MBConvBlock
@@ -12,13 +11,9 @@ from text_recognizer.networks.efficientnet.utils import (
)
-@define(eq=False)
class EfficientNet(nn.Module):
"""Efficientnet without classification head."""
- def __attrs_pre_init__(self) -> None:
- super().__init__()
-
archs = {
# width, depth, dropout
"b0": (1.0, 1.0, 0.2),
@@ -33,32 +28,32 @@ class EfficientNet(nn.Module):
"l2": (4.3, 5.3, 0.5),
}
- arch: str = field()
- params: Tuple[float, float, float] = field(default=None, init=False)
- stochastic_dropout_rate: float = field(default=0.2)
- bn_momentum: float = field(default=0.99)
- bn_eps: float = field(default=1.0e-3)
- depth: int = field(default=7)
- out_channels: int = field(default=None, init=False)
- _conv_stem: nn.Sequential = field(default=None, init=False)
- _blocks: nn.ModuleList = field(default=None, init=False)
- _conv_head: nn.Sequential = field(default=None, init=False)
-
- def __attrs_post_init__(self) -> None:
- """Post init configuration."""
+ def __init__(
+ self,
+ arch: str,
+ params: Tuple[float, float, float],
+ stochastic_dropout_rate: float = 0.2,
+ bn_momentum: float = 0.99,
+ bn_eps: float = 1.0e-3,
+ depth: int = 7,
+ ) -> None:
+ super().__init__()
+ self.params = self._get_arch_params(arch)
+ self.stochastic_dropout_rate = stochastic_dropout_rate
+ self.bn_momentum = bn_momentum
+ self.bn_eps = bn_eps
+ self.depth = depth
+ self.out_channels: int
+ self._conv_stem: nn.Sequential
+ self._blocks: nn.ModuleList
+ self._conv_head: nn.Sequential
self._build()
- @depth.validator
- def _check_depth(self, attribute, value: str) -> None:
- if not 5 <= value <= 7:
- raise ValueError(f"Depth has to be between 5 and 7, was: {value}")
-
- @arch.validator
- def _check_arch(self, attribute, value: str) -> None:
+ def _get_arch_params(self, value: str) -> Tuple[float, float, float]:
"""Validates the efficientnet architecure."""
if value not in self.archs:
raise ValueError(f"{value} not a valid architecure.")
- self.params = self.archs[value]
+ return self.archs[value]
def _build(self) -> None:
"""Builds the efficientnet backbone."""
diff --git a/text_recognizer/networks/efficientnet/mbconv.py b/text_recognizer/networks/efficientnet/mbconv.py
index 98e9353..64debd9 100644
--- a/text_recognizer/networks/efficientnet/mbconv.py
+++ b/text_recognizer/networks/efficientnet/mbconv.py
@@ -1,7 +1,6 @@
"""Mobile inverted residual block."""
from typing import Optional, Tuple, Union
-from attrs import define, field
import torch
from torch import nn, Tensor
import torch.nn.functional as F
@@ -14,18 +13,15 @@ def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]:
return (stride,) * 2 if isinstance(stride, int) else stride
-@define(eq=False)
class BaseModule(nn.Module):
"""Base sub module class."""
- bn_momentum: float = field()
- bn_eps: float = field()
- block: nn.Sequential = field(init=False)
-
- def __attrs_pre_init__(self) -> None:
+ def __init__(self, bn_momentum: float, bn_eps: float, block: nn.Sequential) -> None:
super().__init__()
- def __attrs_post_init__(self) -> None:
+ self.bn_momentum = bn_momentum
+ self.bn_eps = bn_eps
+ self.block = block
self._build()
def _build(self) -> None:
@@ -36,12 +32,20 @@ class BaseModule(nn.Module):
return self.block(x)
-@define(auto_attribs=True, eq=False)
class InvertedBottleneck(BaseModule):
"""Inverted bottleneck module."""
- in_channels: int = field()
- out_channels: int = field()
+ def __init__(
+ self,
+ bn_momentum: float,
+ bn_eps: float,
+ block: nn.Sequential,
+ in_channels: int,
+ out_channels: int,
+ ) -> None:
+ super().__init__(bn_momentum, bn_eps, block)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
def _build(self) -> None:
self.block = nn.Sequential(
@@ -60,13 +64,22 @@ class InvertedBottleneck(BaseModule):
)
-@define(auto_attribs=True, eq=False)
class Depthwise(BaseModule):
"""Depthwise convolution module."""
- channels: int = field()
- kernel_size: int = field()
- stride: int = field()
+ def __init__(
+ self,
+ bn_momentum: float,
+ bn_eps: float,
+ block: nn.Sequential,
+ channels: int,
+ kernel_size: int,
+ stride: int,
+ ) -> None:
+ super().__init__(bn_momentum, bn_eps, block)
+ self.channels = channels
+ self.kernel_size = kernel_size
+ self.stride = stride
def _build(self) -> None:
self.block = nn.Sequential(
@@ -85,13 +98,23 @@ class Depthwise(BaseModule):
)
-@define(auto_attribs=True, eq=False)
class SqueezeAndExcite(BaseModule):
"""Sequeeze and excite module."""
- in_channels: int = field()
- channels: int = field()
- se_ratio: float = field()
+ def __init__(
+ self,
+ bn_momentum: float,
+ bn_eps: float,
+ block: nn.Sequential,
+ in_channels: int,
+ channels: int,
+ se_ratio: float,
+ ) -> None:
+ super().__init__(bn_momentum, bn_eps, block)
+
+ self.in_channels = in_channels
+ self.channels = channels
+ self.se_ratio = se_ratio
def _build(self) -> None:
num_squeezed_channels = max(1, int(self.in_channels * self.se_ratio))
@@ -110,12 +133,20 @@ class SqueezeAndExcite(BaseModule):
)
-@define(auto_attribs=True, eq=False)
class Pointwise(BaseModule):
"""Pointwise module."""
- in_channels: int = field()
- out_channels: int = field()
+ def __init__(
+ self,
+ bn_momentum: float,
+ bn_eps: float,
+ block: nn.Sequential,
+ in_channels: int,
+ out_channels: int,
+ ) -> None:
+ super().__init__(bn_momentum, bn_eps, block)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
def _build(self) -> None:
self.block = nn.Sequential(
@@ -133,28 +164,36 @@ class Pointwise(BaseModule):
)
-@define(eq=False)
class MBConvBlock(nn.Module):
"""Mobile Inverted Residual Bottleneck block."""
- def __attrs_pre_init__(self) -> None:
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Tuple[int, int],
+ stride: Tuple[int, int],
+ bn_momentum: float,
+ bn_eps: float,
+ se_ratio: float,
+ expand_ratio: int,
+ ) -> None:
super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.bn_momentum = bn_momentum
+ self.bn_eps = bn_eps
+ self.se_ratio = se_ratio
+ self.expand_ratio = expand_ratio
+ self.pad = self._configure_padding()
+ self._inverted_bottleneck: Optional[InvertedBottleneck]
+ self._depthwise: nn.Sequential
+ self._squeeze_excite: nn.Sequential
+ self._pointwise: nn.Sequential
+ self._build()
- in_channels: int = field()
- out_channels: int = field()
- kernel_size: Tuple[int, int] = field()
- stride: Tuple[int, int] = field(converter=_convert_stride)
- bn_momentum: float = field()
- bn_eps: float = field()
- se_ratio: float = field()
- expand_ratio: int = field()
- pad: Tuple[int, int, int, int] = field(init=False)
- _inverted_bottleneck: Optional[InvertedBottleneck] = field(init=False)
- _depthwise: nn.Sequential = field(init=False)
- _squeeze_excite: nn.Sequential = field(init=False)
- _pointwise: nn.Sequential = field(init=False)
-
- @pad.default
def _configure_padding(self) -> Tuple[int, int, int, int]:
"""Set padding for convolutional layers."""
if self.stride == (2, 2):
@@ -164,10 +203,6 @@ class MBConvBlock(nn.Module):
) * 2
return ((self.kernel_size - 1) // 2,) * 4
- def __attrs_post_init__(self) -> None:
- """Post init configuration."""
- self._build()
-
def _build(self) -> None:
has_se = self.se_ratio is not None and 0.0 < self.se_ratio < 1.0
inner_channels = self.in_channels * self.expand_ratio