summary | refs | log | tree | commit | diff
path: root/text_recognizer/networks/encoders
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/encoders')
-rw-r--r-- text_recognizer/networks/encoders/efficientnet/efficientnet.py | 15
-rw-r--r-- text_recognizer/networks/encoders/efficientnet/mbconv.py | 139
2 files changed, 67 insertions, 87 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
index a36150a..b8eb53b 100644
--- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
@@ -1,4 +1,4 @@
-"""Efficient net."""
+"""Efficientnet backbone."""
from typing import Tuple
import attr
@@ -12,8 +12,10 @@ from .utils import (
)
-@attr.s
+@attr.s(eq=False)
class EfficientNet(nn.Module):
+ """Efficientnet without classification head."""
+
def __attrs_pre_init__(self) -> None:
super().__init__()
@@ -47,11 +49,13 @@ class EfficientNet(nn.Module):
@arch.validator
def check_arch(self, attribute: attr._make.Attribute, value: str) -> None:
+ """Validates the efficientnet architecture."""
if value not in self.archs:
raise ValueError(f"{value} not a valid architecure.")
self.params = self.archs[value]
def _build(self) -> None:
+ """Builds the efficientnet backbone."""
_block_args = block_args()
in_channels = 1 # BW
out_channels = round_filters(32, self.params)
@@ -73,8 +77,9 @@ class EfficientNet(nn.Module):
for args in _block_args:
args.in_channels = round_filters(args.in_channels, self.params)
args.out_channels = round_filters(args.out_channels, self.params)
- args.num_repeats = round_repeats(args.num_repeats, self.params)
- for _ in range(args.num_repeats):
+ num_repeats = round_repeats(args.num_repeats, self.params)
+ del args.num_repeats
+ for _ in range(num_repeats):
self._blocks.append(
MBConvBlock(
**args, bn_momentum=self.bn_momentum, bn_eps=self.bn_eps,
@@ -93,6 +98,7 @@ class EfficientNet(nn.Module):
)
def extract_features(self, x: Tensor) -> Tensor:
+ """Extracts the final feature map layer."""
x = self._conv_stem(x)
for i, block in enumerate(self._blocks):
stochastic_dropout_rate = self.stochastic_dropout_rate
@@ -103,4 +109,5 @@ class EfficientNet(nn.Module):
return x
def forward(self, x: Tensor) -> Tensor:
+ """Returns efficientnet image features."""
return self.extract_features(x)
diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/encoders/efficientnet/mbconv.py
index 3aa63d0..e85df87 100644
--- a/text_recognizer/networks/encoders/efficientnet/mbconv.py
+++ b/text_recognizer/networks/encoders/efficientnet/mbconv.py
@@ -1,76 +1,62 @@
"""Mobile inverted residual block."""
-from typing import Any, Optional, Union, Tuple
+from typing import Optional, Sequence, Union, Tuple
+import attr
import torch
from torch import nn, Tensor
import torch.nn.functional as F
-from .utils import stochastic_depth
+from text_recognizer.networks.encoders.efficientnet.utils import stochastic_depth
+def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]:
+ """Converts int to tuple."""
+ return (
+ (stride,) * 2 if isinstance(stride, int) else stride
+ )
+
+
+@attr.s(eq=False)
class MBConvBlock(nn.Module):
"""Mobile Inverted Residual Bottleneck block."""
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- kernel_size: int,
- stride: Union[Tuple[int, int], int],
- bn_momentum: float,
- bn_eps: float,
- se_ratio: float,
- expand_ratio: int,
- *args: Any,
- **kwargs: Any,
- ) -> None:
+ def __attrs_pre_init__(self) -> None:
super().__init__()
- self.kernel_size = kernel_size
- self.stride = (stride,) * 2 if isinstance(stride, int) else stride
- self.bn_momentum = bn_momentum
- self.bn_eps = bn_eps
- self.in_channels = in_channels
- self.out_channels = out_channels
+ in_channels: int = attr.ib()
+ out_channels: int = attr.ib()
+ kernel_size: Tuple[int, int] = attr.ib()
+ stride: Tuple[int, int] = attr.ib(converter=_convert_stride)
+ bn_momentum: float = attr.ib()
+ bn_eps: float = attr.ib()
+ se_ratio: float = attr.ib()
+ expand_ratio: int = attr.ib()
+ pad: Tuple[int, int, int, int] = attr.ib(init=False)
+ _inverted_bottleneck: nn.Sequential = attr.ib(init=False)
+ _depthwise: nn.Sequential = attr.ib(init=False)
+ _squeeze_excite: nn.Sequential = attr.ib(init=False)
+ _pointwise: nn.Sequential = attr.ib(init=False)
+
+ @pad.default
+ def _configure_padding(self) -> Tuple[int, int, int, int]:
+ """Set padding for convolutional layers."""
if self.stride == (2, 2):
- self.pad = [
+ return (
(self.kernel_size - 1) // 2 - 1,
(self.kernel_size - 1) // 2,
- ] * 2
- else:
- self.pad = [(self.kernel_size - 1) // 2] * 4
-
- # Placeholders for layers.
- self._inverted_bottleneck: nn.Sequential = None
- self._depthwise: nn.Sequential = None
- self._squeeze_excite: nn.Sequential = None
- self._pointwise: nn.Sequential = None
-
- self._build(
- in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=kernel_size,
- stride=stride,
- expand_ratio=expand_ratio,
- se_ratio=se_ratio,
- )
+ ) * 2
+ return ((self.kernel_size - 1) // 2,) * 4
+
+ def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
+ self._build()
- def _build(
- self,
- in_channels: int,
- out_channels: int,
- kernel_size: int,
- stride: Union[Tuple[int, int], int],
- expand_ratio: int,
- se_ratio: float,
- ) -> None:
- has_se = se_ratio is not None and 0.0 < se_ratio < 1.0
- inner_channels = in_channels * expand_ratio
+ def _build(self) -> None:
+ has_se = self.se_ratio is not None and 0.0 < self.se_ratio < 1.0
+ inner_channels = self.in_channels * self.expand_ratio
self._inverted_bottleneck = (
- self._configure_inverted_bottleneck(
- in_channels=in_channels, out_channels=inner_channels,
- )
- if expand_ratio != 1
+ self._configure_inverted_bottleneck(out_channels=inner_channels)
+ if self.expand_ratio != 1
else None
)
@@ -78,31 +64,23 @@ class MBConvBlock(nn.Module):
in_channels=inner_channels,
out_channels=inner_channels,
groups=inner_channels,
- kernel_size=kernel_size,
- stride=stride,
)
self._squeeze_excite = (
self._configure_squeeze_excite(
- in_channels=inner_channels,
- out_channels=inner_channels,
- se_ratio=se_ratio,
+ in_channels=inner_channels, out_channels=inner_channels,
)
if has_se
else None
)
- self._pointwise = self._configure_pointwise(
- in_channels=inner_channels, out_channels=out_channels
- )
+ self._pointwise = self._configure_pointwise(in_channels=inner_channels)
- def _configure_inverted_bottleneck(
- self, in_channels: int, out_channels: int,
- ) -> nn.Sequential:
+ def _configure_inverted_bottleneck(self, out_channels: int) -> nn.Sequential:
"""Expansion phase."""
return nn.Sequential(
nn.Conv2d(
- in_channels=in_channels,
+ in_channels=self.in_channels,
out_channels=out_channels,
kernel_size=1,
bias=False,
@@ -114,19 +92,14 @@ class MBConvBlock(nn.Module):
)
def _configure_depthwise(
- self,
- in_channels: int,
- out_channels: int,
- groups: int,
- kernel_size: int,
- stride: Union[Tuple[int, int], int],
+ self, in_channels: int, out_channels: int, groups: int,
) -> nn.Sequential:
return nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
- kernel_size=kernel_size,
- stride=stride,
+ kernel_size=self.kernel_size,
+ stride=self.stride,
groups=groups,
bias=False,
),
@@ -137,9 +110,9 @@ class MBConvBlock(nn.Module):
)
def _configure_squeeze_excite(
- self, in_channels: int, out_channels: int, se_ratio: float
+ self, in_channels: int, out_channels: int
) -> nn.Sequential:
- num_squeezed_channels = max(1, int(in_channels * se_ratio))
+ num_squeezed_channels = max(1, int(in_channels * self.se_ratio))
return nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
@@ -154,18 +127,18 @@ class MBConvBlock(nn.Module):
),
)
- def _configure_pointwise(
- self, in_channels: int, out_channels: int
- ) -> nn.Sequential:
+ def _configure_pointwise(self, in_channels: int) -> nn.Sequential:
return nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
- out_channels=out_channels,
+ out_channels=self.out_channels,
kernel_size=1,
bias=False,
),
nn.BatchNorm2d(
- num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
+ num_features=self.out_channels,
+ momentum=self.bn_momentum,
+ eps=self.bn_eps,
),
)
@@ -186,8 +159,8 @@ class MBConvBlock(nn.Module):
residual = x
if self._inverted_bottleneck is not None:
x = self._inverted_bottleneck(x)
- x = F.pad(x, self.pad)
+ x = F.pad(x, self.pad)
x = self._depthwise(x)
if self._squeeze_excite is not None: