From 050e1bd284a173d2586ad4607e95d114691db563 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Mon, 22 Nov 2021 22:38:43 +0100
Subject: Move efficientnet from encoder dir

---
 text_recognizer/networks/efficientnet/__init__.py  |   2 +
 .../networks/efficientnet/efficientnet.py          | 124 +++++++++++
 text_recognizer/networks/efficientnet/mbconv.py    | 240 +++++++++++++++++++++
 text_recognizer/networks/efficientnet/utils.py     |  86 ++++++++
 4 files changed, 452 insertions(+)
 create mode 100644 text_recognizer/networks/efficientnet/__init__.py
 create mode 100644 text_recognizer/networks/efficientnet/efficientnet.py
 create mode 100644 text_recognizer/networks/efficientnet/mbconv.py
 create mode 100644 text_recognizer/networks/efficientnet/utils.py

diff --git a/text_recognizer/networks/efficientnet/__init__.py b/text_recognizer/networks/efficientnet/__init__.py
new file mode 100644
index 0000000..344233f
--- /dev/null
+++ b/text_recognizer/networks/efficientnet/__init__.py
@@ -0,0 +1,2 @@
+"""Efficient net."""
+from .efficientnet import EfficientNet
diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py
new file mode 100644
index 0000000..4c9ed75
--- /dev/null
+++ b/text_recognizer/networks/efficientnet/efficientnet.py
@@ -0,0 +1,124 @@
+"""Efficientnet backbone."""
+from typing import Tuple
+
+import attr
+from torch import nn, Tensor
+
+from text_recognizer.networks.efficientnet.mbconv import MBConvBlock
+from text_recognizer.networks.efficientnet.utils import (
+    block_args,
+    round_filters,
+    round_repeats,
+)
+
+
+@attr.s(eq=False)
+class EfficientNet(nn.Module):
+    """Efficientnet without classification head."""
+
+    def __attrs_pre_init__(self) -> None:
+        super().__init__()
+
+    archs = {
+        # width, depth, dropout
+        "b0": (1.0, 1.0, 0.2),
+        "b1": (1.0, 1.1, 0.2),
+        "b2": (1.1, 1.2, 0.3),
+        "b3": (1.2, 1.4, 0.3),
+        "b4": (1.4, 1.8, 0.4),
+        "b5": (1.6, 2.2, 0.4),
+        "b6": (1.8, 2.6, 0.5),
+        "b7": (2.0, 3.1, 0.5),
+        "b8": (2.2, 3.6, 0.5),
+        "l2": (4.3, 5.3, 0.5),
+    }
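+    # Example (illustrative): EfficientNet(arch="b3") selects
+    # (width, depth, dropout) = (1.2, 1.4, 0.3); filter counts are scaled by
+    # 1.2 and block repeats by 1.4 when the backbone is built.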
+
+    arch: str = attr.ib()
+    params: Tuple[float, float, float] = attr.ib(default=None, init=False)
+    stochastic_dropout_rate: float = attr.ib(default=0.2)
+    # N.b. nn.BatchNorm2d's momentum weights the *new* observation, which is
+    # the inverse of TensorFlow's convention.
+    bn_momentum: float = attr.ib(default=0.99)
+    bn_eps: float = attr.ib(default=1.0e-3)
+    depth: int = attr.ib(default=7)
+    out_channels: int = attr.ib(default=None, init=False)
+    _conv_stem: nn.Sequential = attr.ib(default=None, init=False)
+    _blocks: nn.ModuleList = attr.ib(default=None, init=False)
+    _conv_head: nn.Sequential = attr.ib(default=None, init=False)
+
+    def __attrs_post_init__(self) -> None:
+        """Post init configuration."""
+        self._build()
+
+    @depth.validator
+    def _check_depth(self, attribute: attr.Attribute, value: int) -> None:
+        if not 5 <= value <= 7:
+            raise ValueError(f"Depth has to be between 5 and 7, was: {value}")
+
+    @arch.validator
+    def _check_arch(self, attribute: attr.Attribute, value: str) -> None:
+        """Validates the efficientnet architecture."""
+        if value not in self.archs:
+            raise ValueError(f"{value} is not a valid architecture.")
+        self.params = self.archs[value]
+
+    def _build(self) -> None:
+        """Builds the efficientnet backbone."""
+        _block_args = block_args()[: self.depth]
+        in_channels = 1  # Grayscale (single channel) input.
+        out_channels = round_filters(32, self.params)
+        self._conv_stem = nn.Sequential(
+            nn.ZeroPad2d((0, 1, 0, 1)),
+            nn.Conv2d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=3,
+                stride=(2, 2),
+                bias=False,
+            ),
+            nn.BatchNorm2d(
+                num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
+            ),
+            nn.Mish(inplace=True),
+        )
+        self._blocks = nn.ModuleList([])
+        for args in _block_args:
+            args.in_channels = round_filters(args.in_channels, self.params)
+            args.out_channels = round_filters(args.out_channels, self.params)
+            num_repeats = round_repeats(args.num_repeats, self.params)
+            del args.num_repeats
+            for _ in range(num_repeats):
+                self._blocks.append(
+                    MBConvBlock(
+                        **args, bn_momentum=self.bn_momentum, bn_eps=self.bn_eps,
+                    )
+                )
+                args.in_channels = args.out_channels
+                args.stride = 1
+
+        # The final block's out_channels has already been rounded in the loop.
+        in_channels = _block_args[-1].out_channels
+        self.out_channels = round_filters(1280, self.params)
+        self._conv_head = nn.Sequential(
+            nn.Conv2d(
+                in_channels, self.out_channels, kernel_size=1, stride=1, bias=False
+            ),
+            nn.BatchNorm2d(
+                num_features=self.out_channels,
+                momentum=self.bn_momentum,
+                eps=self.bn_eps,
+            ),
+            nn.Dropout(p=self.params[-1]),
+        )
+
+    def extract_features(self, x: Tensor) -> Tensor:
+        """Extracts the final feature map layer."""
+        x = self._conv_stem(x)
+        for i, block in enumerate(self._blocks):
+            stochastic_dropout_rate = self.stochastic_dropout_rate
+            if self.stochastic_dropout_rate:
+                stochastic_dropout_rate *= i / len(self._blocks)
+            x = block(x, stochastic_dropout_rate=stochastic_dropout_rate)
+        x = self._conv_head(x)
+        return x
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Returns efficientnet image features."""
+        return self.extract_features(x)
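+
+
+# Usage sketch (illustrative, not part of the original module): build a "b0"
+# backbone and extract features from a grayscale image. The output height and
+# width depend on the input size and the block strides.
+#
+#   import torch
+#   net = EfficientNet(arch="b0")
+#   x = torch.rand(1, 1, 576, 640)
+#   features = net(x)  # (1, net.out_channels, H', W')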
diff --git a/text_recognizer/networks/efficientnet/mbconv.py b/text_recognizer/networks/efficientnet/mbconv.py
new file mode 100644
index 0000000..4b051eb
--- /dev/null
+++ b/text_recognizer/networks/efficientnet/mbconv.py
@@ -0,0 +1,240 @@
+"""Mobile inverted residual block."""
+from typing import Optional, Tuple, Union
+
+import attr
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+
+from text_recognizer.networks.efficientnet.utils import stochastic_depth
+
+
+def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]:
+    """Converts int to tuple."""
+    return (stride,) * 2 if isinstance(stride, int) else stride
+
+
+@attr.s(eq=False)
+class BaseModule(nn.Module):
+    """Base sub module class."""
+
+    bn_momentum: float = attr.ib()
+    bn_eps: float = attr.ib()
+    block: nn.Sequential = attr.ib(init=False)
+
+    def __attrs_pre_init__(self) -> None:
+        super().__init__()
+
+    def __attrs_post_init__(self) -> None:
+        self._build()
+
+    def _build(self) -> None:
+        pass
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward pass."""
+        return self.block(x)
+
+
+@attr.s(auto_attribs=True, eq=False)
+class InvertedBottleneck(BaseModule):
+    """Inverted bottleneck module."""
+
+    in_channels: int = attr.ib()
+    out_channels: int = attr.ib()
+
+    def _build(self) -> None:
+        self.block = nn.Sequential(
+            nn.Conv2d(
+                in_channels=self.in_channels,
+                out_channels=self.out_channels,
+                kernel_size=1,
+                bias=False,
+            ),
+            nn.BatchNorm2d(
+                num_features=self.out_channels,
+                momentum=self.bn_momentum,
+                eps=self.bn_eps,
+            ),
+            nn.Mish(inplace=True),
+        )
+
+
+@attr.s(auto_attribs=True, eq=False)
+class Depthwise(BaseModule):
+    """Depthwise convolution module."""
+
+    channels: int = attr.ib()
+    kernel_size: int = attr.ib()
+    stride: int = attr.ib()
+
+    def _build(self) -> None:
+        self.block = nn.Sequential(
+            nn.Conv2d(
+                in_channels=self.channels,
+                out_channels=self.channels,
+                kernel_size=self.kernel_size,
+                stride=self.stride,
+                groups=self.channels,
+                bias=False,
+            ),
+            nn.BatchNorm2d(
+                num_features=self.channels, momentum=self.bn_momentum, eps=self.bn_eps
+            ),
+            nn.Mish(inplace=True),
+        )
+
+
+@attr.s(auto_attribs=True, eq=False)
+class SqueezeAndExcite(BaseModule):
+    """Sequeeze and excite module."""
+
+    in_channels: int = attr.ib()
+    channels: int = attr.ib()
+    se_ratio: float = attr.ib()
+
+    def _build(self) -> None:
+        num_squeezed_channels = max(1, int(self.in_channels * self.se_ratio))
+        self.block = nn.Sequential(
+            nn.Conv2d(
+                in_channels=self.channels,
+                out_channels=num_squeezed_channels,
+                kernel_size=1,
+            ),
+            nn.Mish(inplace=True),
+            nn.Conv2d(
+                in_channels=num_squeezed_channels,
+                out_channels=self.channels,
+                kernel_size=1,
+            ),
+        )
+
+
+@attr.s(auto_attribs=True, eq=False)
+class Pointwise(BaseModule):
+    """Pointwise module."""
+
+    in_channels: int = attr.ib()
+    out_channels: int = attr.ib()
+
+    def _build(self) -> None:
+        self.block = nn.Sequential(
+            nn.Conv2d(
+                in_channels=self.in_channels,
+                out_channels=self.out_channels,
+                kernel_size=1,
+                bias=False,
+            ),
+            nn.BatchNorm2d(
+                num_features=self.out_channels,
+                momentum=self.bn_momentum,
+                eps=self.bn_eps,
+            ),
+        )
+
+
+@attr.s(eq=False)
+class MBConvBlock(nn.Module):
+    """Mobile Inverted Residual Bottleneck block."""
+
+    def __attrs_pre_init__(self) -> None:
+        super().__init__()
+
+    in_channels: int = attr.ib()
+    out_channels: int = attr.ib()
+    kernel_size: int = attr.ib()
+    stride: Tuple[int, int] = attr.ib(converter=_convert_stride)
+    bn_momentum: float = attr.ib()
+    bn_eps: float = attr.ib()
+    se_ratio: float = attr.ib()
+    expand_ratio: int = attr.ib()
+    pad: Tuple[int, int, int, int] = attr.ib(init=False)
+    _inverted_bottleneck: Optional[InvertedBottleneck] = attr.ib(init=False)
+    _depthwise: nn.Sequential = attr.ib(init=False)
+    _squeeze_excite: nn.Sequential = attr.ib(init=False)
+    _pointwise: nn.Sequential = attr.ib(init=False)
+
+    @pad.default
+    def _configure_padding(self) -> Tuple[int, int, int, int]:
+        """Set padding for convolutional layers."""
+        if self.stride == (2, 2):
+            return ((self.kernel_size - 1) // 2 - 1, (self.kernel_size - 1) // 2,) * 2
+        return ((self.kernel_size - 1) // 2,) * 4
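+
+    # Example (illustrative): kernel_size=3 yields padding (1, 1, 1, 1) for
+    # stride (1, 1) and the asymmetric (0, 1, 0, 1) for stride (2, 2),
+    # matching TensorFlow-style "same" padding.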
+
+    def __attrs_post_init__(self) -> None:
+        """Post init configuration."""
+        self._build()
+
+    def _build(self) -> None:
+        has_se = self.se_ratio is not None and 0.0 < self.se_ratio < 1.0
+        inner_channels = self.in_channels * self.expand_ratio
+        self._inverted_bottleneck = (
+            InvertedBottleneck(
+                in_channels=self.in_channels,
+                out_channels=inner_channels,
+                bn_momentum=self.bn_momentum,
+                bn_eps=self.bn_eps,
+            )
+            if self.expand_ratio != 1
+            else None
+        )
+
+        self._depthwise = Depthwise(
+            channels=inner_channels,
+            kernel_size=self.kernel_size,
+            stride=self.stride,
+            bn_momentum=self.bn_momentum,
+            bn_eps=self.bn_eps,
+        )
+
+        self._squeeze_excite = (
+            SqueezeAndExcite(
+                in_channels=self.in_channels,
+                channels=inner_channels,
+                se_ratio=self.se_ratio,
+                bn_momentum=self.bn_momentum,
+                bn_eps=self.bn_eps,
+            )
+            if has_se
+            else None
+        )
+
+        self._pointwise = Pointwise(
+            in_channels=inner_channels,
+            out_channels=self.out_channels,
+            bn_momentum=self.bn_momentum,
+            bn_eps=self.bn_eps,
+        )
+
+    def _stochastic_depth(
+        self, x: Tensor, residual: Tensor, stochastic_dropout_rate: Optional[float]
+    ) -> Tensor:
+        if self.stride == (1, 1) and self.in_channels == self.out_channels:
+            if stochastic_dropout_rate:
+                x = stochastic_depth(
+                    x, p=stochastic_dropout_rate, training=self.training
+                )
+            x += residual
+        return x
+
+    def forward(
+        self, x: Tensor, stochastic_dropout_rate: Optional[float] = None
+    ) -> Tensor:
+        """Forward pass."""
+        residual = x
+        if self._inverted_bottleneck is not None:
+            x = self._inverted_bottleneck(x)
+
+        x = F.pad(x, self.pad)
+        x = self._depthwise(x)
+
+        if self._squeeze_excite is not None:
+            x_squeezed = F.adaptive_avg_pool2d(x, 1)
+            x_squeezed = self._squeeze_excite(x_squeezed)
+            # Mish-style gating of the excitation.
+            x = torch.tanh(F.softplus(x_squeezed)) * x
+
+        x = self._pointwise(x)
+
+        # Stochastic depth
+        x = self._stochastic_depth(x, residual, stochastic_dropout_rate)
+        return x
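+
+
+# Usage sketch (illustrative, not part of the original module): a single
+# block with the dimensions of the first stage in utils.block_args(). With
+# stride 1 but differing in/out channels, no residual connection is added.
+#
+#   block = MBConvBlock(
+#       in_channels=32, out_channels=16, kernel_size=3, stride=1,
+#       bn_momentum=0.99, bn_eps=1.0e-3, se_ratio=0.25, expand_ratio=1,
+#   )
+#   x = torch.rand(1, 32, 64, 64)
+#   out = block(x, stochastic_dropout_rate=0.1)  # (1, 16, 64, 64)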
diff --git a/text_recognizer/networks/efficientnet/utils.py b/text_recognizer/networks/efficientnet/utils.py
new file mode 100644
index 0000000..5234324
--- /dev/null
+++ b/text_recognizer/networks/efficientnet/utils.py
@@ -0,0 +1,86 @@
+"""Util functions for efficient net."""
+import math
+from typing import List, Tuple
+
+from omegaconf import DictConfig, OmegaConf
+import torch
+from torch import Tensor
+
+
+def stochastic_depth(x: Tensor, p: float, training: bool) -> Tensor:
+    """Stochastic connection.
+
+    Drops the entire convolution with a given survival probability.
+
+    Args:
+        x (Tensor): Input tensor.
+        p (float): Survival probability between 0.0 and 1.0.
+        training (bool): The running mode.
+
+    Shapes:
+        - x: :math: `(B, C, W, H)`.
+        - out: :math: `(B, C, W, H)`.
+
+        where B is the batch size, C is the number of channels, W is the width, and H
+        is the height.
+
+    Returns:
+        out (Tensor): Output after drop connection.
+    """
+    assert 0.0 <= p <= 1.0, "p must be in the range [0, 1]"
+
+    if not training:
+        return x
+
+    bsz = x.shape[0]
+    survival_prob = 1 - p
+
+    # Generate a binary tensor mask according to probability (p for 0, 1-p for 1)
+    random_tensor = survival_prob
+    random_tensor += torch.rand([bsz, 1, 1, 1]).type_as(x)
+    binary_tensor = torch.floor(random_tensor)
+
+    out = x / survival_prob * binary_tensor
+    return out
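+
+# Example (illustrative): with p=0.2 each sample keeps its branch with
+# probability 0.8, and kept activations are rescaled by 1 / 0.8:
+#
+#   x = torch.ones(4, 1, 1, 1)
+#   out = stochastic_depth(x, p=0.2, training=True)
+#   # each out[i, 0, 0, 0] is either 0.0 or 1.25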
+
+
+def round_filters(filters: int, arch: Tuple[float, float, float]) -> int:
+    """Returns the number output filters for a block."""
+    multiplier = arch[0]
+    divisor = 8
+    filters *= multiplier
+    new_filters = max(divisor, (filters + divisor // 2) // divisor * divisor)
+    if new_filters < 0.9 * filters:
+        new_filters += divisor
+    return int(new_filters)
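+
+# Example (illustrative): with the "b3" parameters (1.2, 1.4, 0.3),
+# round_filters(32, (1.2, 1.4, 0.3)) scales 32 to 38.4 and rounds to the
+# nearest multiple of 8, giving 40.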
+
+
+def round_repeats(repeats: int, arch: Tuple[float, float, float]) -> int:
+    """Returns how many times a layer should be repeated in a block."""
+    return int(math.ceil(arch[1] * repeats))
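+
+# Example (illustrative): round_repeats(3, (1.2, 1.4, 0.3)) multiplies 3 by
+# the depth coefficient 1.4 and rounds up: ceil(4.2) = 5 repeats.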
+
+
+def block_args() -> List[DictConfig]:
+    """Returns arguments for each efficientnet block."""
+    keys = [
+        "num_repeats",
+        "kernel_size",
+        "stride",
+        "expand_ratio",
+        "in_channels",
+        "out_channels",
+        "se_ratio",
+    ]
+    args = [
+        [1, 3, (1, 1), 1, 32, 16, 0.25],
+        [2, 3, (2, 2), 6, 16, 24, 0.25],
+        [2, 5, (2, 2), 6, 24, 40, 0.25],
+        [3, 3, (2, 2), 6, 40, 80, 0.25],
+        [3, 5, (1, 1), 6, 80, 112, 0.25],
+        [4, 5, (2, 2), 6, 112, 192, 0.25],
+        [1, 3, (1, 1), 6, 192, 320, 0.25],
+    ]
+    block_args_ = []
+    for row in args:
+        block_args_.append(OmegaConf.create(dict(zip(keys, row))))
+    return block_args_
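+
+
+# Example (illustrative): block_args()[0] describes the first MBConv stage:
+#   num_repeats=1, kernel_size=3, stride=(1, 1), expand_ratio=1,
+#   in_channels=32, out_channels=16, se_ratio=0.25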
-- 
cgit v1.2.3-70-g09d2