summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/efficientnet
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/efficientnet')
-rw-r--r--text_recognizer/networks/efficientnet/__init__.py2
-rw-r--r--text_recognizer/networks/efficientnet/efficientnet.py124
-rw-r--r--text_recognizer/networks/efficientnet/mbconv.py240
-rw-r--r--text_recognizer/networks/efficientnet/utils.py86
4 files changed, 452 insertions, 0 deletions
diff --git a/text_recognizer/networks/efficientnet/__init__.py b/text_recognizer/networks/efficientnet/__init__.py
new file mode 100644
index 0000000..344233f
--- /dev/null
+++ b/text_recognizer/networks/efficientnet/__init__.py
@@ -0,0 +1,2 @@
+"""Efficient net."""
+from .efficientnet import EfficientNet
diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py
new file mode 100644
index 0000000..4c9ed75
--- /dev/null
+++ b/text_recognizer/networks/efficientnet/efficientnet.py
@@ -0,0 +1,124 @@
+"""Efficientnet backbone."""
+from typing import Tuple
+
+import attr
+from torch import nn, Tensor
+
+from text_recognizer.networks.efficientnet.mbconv import MBConvBlock
+from text_recognizer.networks.efficientnet.utils import (
+ block_args,
+ round_filters,
+ round_repeats,
+)
+
+
+@attr.s(eq=False)
+class EfficientNet(nn.Module):
+ """Efficientnet without classification head."""
+
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
+
+ archs = {
+ # width, depth, dropout
+ "b0": (1.0, 1.0, 0.2),
+ "b1": (1.0, 1.1, 0.2),
+ "b2": (1.1, 1.2, 0.3),
+ "b3": (1.2, 1.4, 0.3),
+ "b4": (1.4, 1.8, 0.4),
+ "b5": (1.6, 2.2, 0.4),
+ "b6": (1.8, 2.6, 0.5),
+ "b7": (2.0, 3.1, 0.5),
+ "b8": (2.2, 3.6, 0.5),
+ "l2": (4.3, 5.3, 0.5),
+ }
+
+ arch: str = attr.ib()
+ params: Tuple[float, float, float] = attr.ib(default=None, init=False)
+ stochastic_dropout_rate: float = attr.ib(default=0.2)
+ bn_momentum: float = attr.ib(default=0.99)
+ bn_eps: float = attr.ib(default=1.0e-3)
+ depth: int = attr.ib(default=7)
+ out_channels: int = attr.ib(default=None, init=False)
+ _conv_stem: nn.Sequential = attr.ib(default=None, init=False)
+ _blocks: nn.ModuleList = attr.ib(default=None, init=False)
+ _conv_head: nn.Sequential = attr.ib(default=None, init=False)
+
+ def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
+ self._build()
+
+ @depth.validator
+ def _check_depth(self, attribute: attr._make.Attribute, value: str) -> None:
+ if not 5 <= value <= 7:
+ raise ValueError(f"Depth has to be between 5 and 7, was: {value}")
+
+ @arch.validator
+ def _check_arch(self, attribute: attr._make.Attribute, value: str) -> None:
+ """Validates the efficientnet architecure."""
+ if value not in self.archs:
+ raise ValueError(f"{value} not a valid architecure.")
+ self.params = self.archs[value]
+
+ def _build(self) -> None:
+ """Builds the efficientnet backbone."""
+ _block_args = block_args()[: self.depth]
+ in_channels = 1 # BW
+ out_channels = round_filters(32, self.params)
+ self._conv_stem = nn.Sequential(
+ nn.ZeroPad2d((0, 1, 0, 1)),
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=(2, 2),
+ bias=False,
+ ),
+ nn.BatchNorm2d(
+ num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
+ ),
+ nn.Mish(inplace=True),
+ )
+ self._blocks = nn.ModuleList([])
+ for args in _block_args:
+ args.in_channels = round_filters(args.in_channels, self.params)
+ args.out_channels = round_filters(args.out_channels, self.params)
+ num_repeats = round_repeats(args.num_repeats, self.params)
+ del args.num_repeats
+ for _ in range(num_repeats):
+ self._blocks.append(
+ MBConvBlock(
+ **args, bn_momentum=self.bn_momentum, bn_eps=self.bn_eps,
+ )
+ )
+ args.in_channels = args.out_channels
+ args.stride = 1
+
+ in_channels = round_filters(_block_args[-1].out_channels, self.params)
+ self.out_channels = round_filters(1280, self.params)
+ self._conv_head = nn.Sequential(
+ nn.Conv2d(
+ in_channels, self.out_channels, kernel_size=1, stride=1, bias=False
+ ),
+ nn.BatchNorm2d(
+ num_features=self.out_channels,
+ momentum=self.bn_momentum,
+ eps=self.bn_eps,
+ ),
+ nn.Dropout(p=self.params[-1]),
+ )
+
+ def extract_features(self, x: Tensor) -> Tensor:
+ """Extracts the final feature map layer."""
+ x = self._conv_stem(x)
+ for i, block in enumerate(self._blocks):
+ stochastic_dropout_rate = self.stochastic_dropout_rate
+ if self.stochastic_dropout_rate:
+ stochastic_dropout_rate *= i / len(self._blocks)
+ x = block(x, stochastic_dropout_rate=stochastic_dropout_rate)
+ x = self._conv_head(x)
+ return x
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Returns efficientnet image features."""
+ return self.extract_features(x)
diff --git a/text_recognizer/networks/efficientnet/mbconv.py b/text_recognizer/networks/efficientnet/mbconv.py
new file mode 100644
index 0000000..4b051eb
--- /dev/null
+++ b/text_recognizer/networks/efficientnet/mbconv.py
@@ -0,0 +1,240 @@
+"""Mobile inverted residual block."""
+from typing import Optional, Tuple, Union
+
+import attr
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+
+from text_recognizer.networks.encoders.efficientnet.utils import stochastic_depth
+
+
+def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]:
+ """Converts int to tuple."""
+ return (stride,) * 2 if isinstance(stride, int) else stride
+
+
+@attr.s(eq=False)
+class BaseModule(nn.Module):
+ """Base sub module class."""
+
+ bn_momentum: float = attr.ib()
+ bn_eps: float = attr.ib()
+ block: nn.Sequential = attr.ib(init=False)
+
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
+
+ def __attrs_post_init__(self) -> None:
+ self._build()
+
+ def _build(self) -> None:
+ pass
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ return self.block(x)
+
+
+@attr.s(auto_attribs=True, eq=False)
+class InvertedBottleneck(BaseModule):
+ """Inverted bottleneck module."""
+
+ in_channels: int = attr.ib()
+ out_channels: int = attr.ib()
+
+ def _build(self) -> None:
+ self.block = nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.in_channels,
+ out_channels=self.out_channels,
+ kernel_size=1,
+ bias=False,
+ ),
+ nn.BatchNorm2d(
+ num_features=self.out_channels,
+ momentum=self.bn_momentum,
+ eps=self.bn_eps,
+ ),
+ nn.Mish(inplace=True),
+ )
+
+
+@attr.s(auto_attribs=True, eq=False)
+class Depthwise(BaseModule):
+ """Depthwise convolution module."""
+
+ channels: int = attr.ib()
+ kernel_size: int = attr.ib()
+ stride: int = attr.ib()
+
+ def _build(self) -> None:
+ self.block = nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.channels,
+ out_channels=self.channels,
+ kernel_size=self.kernel_size,
+ stride=self.stride,
+ groups=self.channels,
+ bias=False,
+ ),
+ nn.BatchNorm2d(
+ num_features=self.channels, momentum=self.bn_momentum, eps=self.bn_eps
+ ),
+ nn.Mish(inplace=True),
+ )
+
+
+@attr.s(auto_attribs=True, eq=False)
+class SqueezeAndExcite(BaseModule):
+ """Sequeeze and excite module."""
+
+ in_channels: int = attr.ib()
+ channels: int = attr.ib()
+ se_ratio: float = attr.ib()
+
+ def _build(self) -> None:
+ num_squeezed_channels = max(1, int(self.in_channels * self.se_ratio))
+ self.block = nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.channels,
+ out_channels=num_squeezed_channels,
+ kernel_size=1,
+ ),
+ nn.Mish(inplace=True),
+ nn.Conv2d(
+ in_channels=num_squeezed_channels,
+ out_channels=self.channels,
+ kernel_size=1,
+ ),
+ )
+
+
+@attr.s(auto_attribs=True, eq=False)
+class Pointwise(BaseModule):
+ """Pointwise module."""
+
+ in_channels: int = attr.ib()
+ out_channels: int = attr.ib()
+
+ def _build(self) -> None:
+ self.block = nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.in_channels,
+ out_channels=self.out_channels,
+ kernel_size=1,
+ bias=False,
+ ),
+ nn.BatchNorm2d(
+ num_features=self.out_channels,
+ momentum=self.bn_momentum,
+ eps=self.bn_eps,
+ ),
+ )
+
+
+@attr.s(eq=False)
+class MBConvBlock(nn.Module):
+ """Mobile Inverted Residual Bottleneck block."""
+
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
+
+ in_channels: int = attr.ib()
+ out_channels: int = attr.ib()
+ kernel_size: Tuple[int, int] = attr.ib()
+ stride: Tuple[int, int] = attr.ib(converter=_convert_stride)
+ bn_momentum: float = attr.ib()
+ bn_eps: float = attr.ib()
+ se_ratio: float = attr.ib()
+ expand_ratio: int = attr.ib()
+ pad: Tuple[int, int, int, int] = attr.ib(init=False)
+ _inverted_bottleneck: Optional[InvertedBottleneck] = attr.ib(init=False)
+ _depthwise: nn.Sequential = attr.ib(init=False)
+ _squeeze_excite: nn.Sequential = attr.ib(init=False)
+ _pointwise: nn.Sequential = attr.ib(init=False)
+
+ @pad.default
+ def _configure_padding(self) -> Tuple[int, int, int, int]:
+ """Set padding for convolutional layers."""
+ if self.stride == (2, 2):
+ return ((self.kernel_size - 1) // 2 - 1, (self.kernel_size - 1) // 2,) * 2
+ return ((self.kernel_size - 1) // 2,) * 4
+
+ def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
+ self._build()
+
+ def _build(self) -> None:
+ has_se = self.se_ratio is not None and 0.0 < self.se_ratio < 1.0
+ inner_channels = self.in_channels * self.expand_ratio
+ self._inverted_bottleneck = (
+ InvertedBottleneck(
+ in_channels=self.in_channels,
+ out_channels=inner_channels,
+ bn_momentum=self.bn_momentum,
+ bn_eps=self.bn_eps,
+ )
+ if self.expand_ratio != 1
+ else None
+ )
+
+ self._depthwise = Depthwise(
+ channels=inner_channels,
+ kernel_size=self.kernel_size,
+ stride=self.stride,
+ bn_momentum=self.bn_momentum,
+ bn_eps=self.bn_eps,
+ )
+
+ self._squeeze_excite = (
+ SqueezeAndExcite(
+ in_channels=self.in_channels,
+ channels=inner_channels,
+ se_ratio=self.se_ratio,
+ bn_momentum=self.bn_momentum,
+ bn_eps=self.bn_eps,
+ )
+ if has_se
+ else None
+ )
+
+ self._pointwise = Pointwise(
+ in_channels=inner_channels,
+ out_channels=self.out_channels,
+ bn_momentum=self.bn_momentum,
+ bn_eps=self.bn_eps,
+ )
+
+ def _stochastic_depth(
+ self, x: Tensor, residual: Tensor, stochastic_dropout_rate: Optional[float]
+ ) -> Tensor:
+ if self.stride == (1, 1) and self.in_channels == self.out_channels:
+ if stochastic_dropout_rate:
+ x = stochastic_depth(
+ x, p=stochastic_dropout_rate, training=self.training
+ )
+ x += residual
+ return x
+
+ def forward(
+ self, x: Tensor, stochastic_dropout_rate: Optional[float] = None
+ ) -> Tensor:
+ """Forward pass."""
+ residual = x
+ if self._inverted_bottleneck is not None:
+ x = self._inverted_bottleneck(x)
+
+ x = F.pad(x, self.pad)
+ x = self._depthwise(x)
+
+ if self._squeeze_excite is not None:
+ x_squeezed = F.adaptive_avg_pool2d(x, 1)
+ x_squeezed = self._squeeze_excite(x)
+ x = torch.tanh(F.softplus(x_squeezed)) * x
+
+ x = self._pointwise(x)
+
+ # Stochastic depth
+ x = self._stochastic_depth(x, residual, stochastic_dropout_rate)
+ return x
diff --git a/text_recognizer/networks/efficientnet/utils.py b/text_recognizer/networks/efficientnet/utils.py
new file mode 100644
index 0000000..5234324
--- /dev/null
+++ b/text_recognizer/networks/efficientnet/utils.py
@@ -0,0 +1,86 @@
+"""Util functions for efficient net."""
+import math
+from typing import List, Tuple
+
+from omegaconf import DictConfig, OmegaConf
+import torch
+from torch import Tensor
+
+
+def stochastic_depth(x: Tensor, p: float, training: bool) -> Tensor:
+ """Stochastic connection.
+
+ Drops the entire convolution with a given survival probability.
+
+ Args:
+ x (Tensor): Input tensor.
+ p (float): Survival probability between 0.0 and 1.0.
+ training (bool): The running mode.
+
+ Shapes:
+ - x: :math: `(B, C, W, H)`.
+ - out: :math: `(B, C, W, H)`.
+
+ where B is the batch size, C is the number of channels, W is the width, and H
+ is the height.
+
+ Returns:
+ out (Tensor): Output after drop connection.
+ """
+ assert 0.0 <= p <= 1.0, "p must be in range of [0, 1]"
+
+ if not training:
+ return x
+
+ bsz = x.shape[0]
+ survival_prob = 1 - p
+
+ # Generate a binary tensor mask according to probability (p for 0, 1-p for 1)
+ random_tensor = survival_prob
+ random_tensor += torch.rand([bsz, 1, 1, 1]).type_as(x)
+ binary_tensor = torch.floor(random_tensor)
+
+ out = x / survival_prob * binary_tensor
+ return out
+
+
+def round_filters(filters: int, arch: Tuple[float, float, float]) -> int:
+ """Returns the number output filters for a block."""
+ multiplier = arch[0]
+ divisor = 8
+ filters *= multiplier
+ new_filters = max(divisor, (filters + divisor // 2) // divisor * divisor)
+ if new_filters < 0.9 * filters:
+ new_filters += divisor
+ return int(new_filters)
+
+
+def round_repeats(repeats: int, arch: Tuple[float, float, float]) -> int:
+ """Returns how many times a layer should be repeated in a block."""
+ return int(math.ceil(arch[1] * repeats))
+
+
+def block_args() -> List[DictConfig]:
+ """Returns arguments for each efficientnet block."""
+ keys = [
+ "num_repeats",
+ "kernel_size",
+ "stride",
+ "expand_ratio",
+ "in_channels",
+ "out_channels",
+ "se_ratio",
+ ]
+ args = [
+ [1, 3, (1, 1), 1, 32, 16, 0.25],
+ [2, 3, (2, 2), 6, 16, 24, 0.25],
+ [2, 5, (2, 2), 6, 24, 40, 0.25],
+ [3, 3, (2, 2), 6, 40, 80, 0.25],
+ [3, 5, (1, 1), 6, 80, 112, 0.25],
+ [4, 5, (2, 2), 6, 112, 192, 0.25],
+ [1, 3, (1, 1), 6, 192, 320, 0.25],
+ ]
+ block_args_ = []
+ for row in args:
+ block_args_.append(OmegaConf.create(dict(zip(keys, row))))
+ return block_args_