diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-22 22:38:43 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-22 22:38:43 +0100 |
commit | 050e1bd284a173d2586ad4607e95d114691db563 (patch) | |
tree | f428b5a17396a7cd585e89d84765e3d7ed233618 /text_recognizer/networks/efficientnet | |
parent | 36875ceca1e00f5bb39a151c50ecf5e333b4cf79 (diff) |
Move efficientnet from encoder dir
Diffstat (limited to 'text_recognizer/networks/efficientnet')
-rw-r--r-- | text_recognizer/networks/efficientnet/__init__.py | 2 | ||||
-rw-r--r-- | text_recognizer/networks/efficientnet/efficientnet.py | 124 | ||||
-rw-r--r-- | text_recognizer/networks/efficientnet/mbconv.py | 240 | ||||
-rw-r--r-- | text_recognizer/networks/efficientnet/utils.py | 86 |
4 files changed, 452 insertions, 0 deletions
diff --git a/text_recognizer/networks/efficientnet/__init__.py b/text_recognizer/networks/efficientnet/__init__.py new file mode 100644 index 0000000..344233f --- /dev/null +++ b/text_recognizer/networks/efficientnet/__init__.py @@ -0,0 +1,2 @@ +"""Efficient net.""" +from .efficientnet import EfficientNet diff --git a/text_recognizer/networks/efficientnet/efficientnet.py b/text_recognizer/networks/efficientnet/efficientnet.py new file mode 100644 index 0000000..4c9ed75 --- /dev/null +++ b/text_recognizer/networks/efficientnet/efficientnet.py @@ -0,0 +1,124 @@ +"""Efficientnet backbone.""" +from typing import Tuple + +import attr +from torch import nn, Tensor + +from text_recognizer.networks.efficientnet.mbconv import MBConvBlock +from text_recognizer.networks.efficientnet.utils import ( + block_args, + round_filters, + round_repeats, +) + + +@attr.s(eq=False) +class EfficientNet(nn.Module): + """Efficientnet without classification head.""" + + def __attrs_pre_init__(self) -> None: + super().__init__() + + archs = { + # width, depth, dropout + "b0": (1.0, 1.0, 0.2), + "b1": (1.0, 1.1, 0.2), + "b2": (1.1, 1.2, 0.3), + "b3": (1.2, 1.4, 0.3), + "b4": (1.4, 1.8, 0.4), + "b5": (1.6, 2.2, 0.4), + "b6": (1.8, 2.6, 0.5), + "b7": (2.0, 3.1, 0.5), + "b8": (2.2, 3.6, 0.5), + "l2": (4.3, 5.3, 0.5), + } + + arch: str = attr.ib() + params: Tuple[float, float, float] = attr.ib(default=None, init=False) + stochastic_dropout_rate: float = attr.ib(default=0.2) + bn_momentum: float = attr.ib(default=0.99) + bn_eps: float = attr.ib(default=1.0e-3) + depth: int = attr.ib(default=7) + out_channels: int = attr.ib(default=None, init=False) + _conv_stem: nn.Sequential = attr.ib(default=None, init=False) + _blocks: nn.ModuleList = attr.ib(default=None, init=False) + _conv_head: nn.Sequential = attr.ib(default=None, init=False) + + def __attrs_post_init__(self) -> None: + """Post init configuration.""" + self._build() + + @depth.validator + def _check_depth(self, attribute: attr._make.Attribute, value: str) -> None: + if not 5 <= value <= 7: + raise ValueError(f"Depth has to be between 5 and 7, was: {value}") + + @arch.validator + def _check_arch(self, attribute: attr._make.Attribute, value: str) -> None: + """Validates the efficientnet architecure.""" + if value not in self.archs: + raise ValueError(f"{value} not a valid architecure.") + self.params = self.archs[value] + + def _build(self) -> None: + """Builds the efficientnet backbone.""" + _block_args = block_args()[: self.depth] + in_channels = 1 # BW + out_channels = round_filters(32, self.params) + self._conv_stem = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=(2, 2), + bias=False, + ), + nn.BatchNorm2d( + num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps + ), + nn.Mish(inplace=True), + ) + self._blocks = nn.ModuleList([]) + for args in _block_args: + args.in_channels = round_filters(args.in_channels, self.params) + args.out_channels = round_filters(args.out_channels, self.params) + num_repeats = round_repeats(args.num_repeats, self.params) + del args.num_repeats + for _ in range(num_repeats): + self._blocks.append( + MBConvBlock( + **args, bn_momentum=self.bn_momentum, bn_eps=self.bn_eps, + ) + ) + args.in_channels = args.out_channels + args.stride = 1 + + in_channels = round_filters(_block_args[-1].out_channels, self.params) + self.out_channels = round_filters(1280, self.params) + self._conv_head = nn.Sequential( + nn.Conv2d( + in_channels, self.out_channels, kernel_size=1, stride=1, bias=False + ), + nn.BatchNorm2d( + num_features=self.out_channels, + momentum=self.bn_momentum, + eps=self.bn_eps, + ), + nn.Dropout(p=self.params[-1]), + ) + + def extract_features(self, x: Tensor) -> Tensor: + """Extracts the final feature map layer.""" + x = self._conv_stem(x) + for i, block in enumerate(self._blocks): + stochastic_dropout_rate = self.stochastic_dropout_rate + if self.stochastic_dropout_rate: + stochastic_dropout_rate *= i / len(self._blocks) + x = block(x, stochastic_dropout_rate=stochastic_dropout_rate) + x = self._conv_head(x) + return x + + def forward(self, x: Tensor) -> Tensor: + """Returns efficientnet image features.""" + return self.extract_features(x) diff --git a/text_recognizer/networks/efficientnet/mbconv.py b/text_recognizer/networks/efficientnet/mbconv.py new file mode 100644 index 0000000..4b051eb --- /dev/null +++ b/text_recognizer/networks/efficientnet/mbconv.py @@ -0,0 +1,240 @@ +"""Mobile inverted residual block.""" +from typing import Optional, Tuple, Union + +import attr +import torch +from torch import nn, Tensor +import torch.nn.functional as F + +from text_recognizer.networks.encoders.efficientnet.utils import stochastic_depth + + +def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]: + """Converts int to tuple.""" + return (stride,) * 2 if isinstance(stride, int) else stride + + +@attr.s(eq=False) +class BaseModule(nn.Module): + """Base sub module class.""" + + bn_momentum: float = attr.ib() + bn_eps: float = attr.ib() + block: nn.Sequential = attr.ib(init=False) + + def __attrs_pre_init__(self) -> None: + super().__init__() + + def __attrs_post_init__(self) -> None: + self._build() + + def _build(self) -> None: + pass + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + return self.block(x) + + +@attr.s(auto_attribs=True, eq=False) +class InvertedBottleneck(BaseModule): + """Inverted bottleneck module.""" + + in_channels: int = attr.ib() + out_channels: int = attr.ib() + + def _build(self) -> None: + self.block = nn.Sequential( + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=1, + bias=False, + ), + nn.BatchNorm2d( + num_features=self.out_channels, + momentum=self.bn_momentum, + eps=self.bn_eps, + ), + nn.Mish(inplace=True), + ) + + +@attr.s(auto_attribs=True, eq=False) +class Depthwise(BaseModule): + """Depthwise convolution module.""" + + channels: int = attr.ib() + kernel_size: int = attr.ib() + stride: int = attr.ib() + + def _build(self) -> None: + self.block = nn.Sequential( + nn.Conv2d( + in_channels=self.channels, + out_channels=self.channels, + kernel_size=self.kernel_size, + stride=self.stride, + groups=self.channels, + bias=False, + ), + nn.BatchNorm2d( + num_features=self.channels, momentum=self.bn_momentum, eps=self.bn_eps + ), + nn.Mish(inplace=True), + ) + + +@attr.s(auto_attribs=True, eq=False) +class SqueezeAndExcite(BaseModule): + """Sequeeze and excite module.""" + + in_channels: int = attr.ib() + channels: int = attr.ib() + se_ratio: float = attr.ib() + + def _build(self) -> None: + num_squeezed_channels = max(1, int(self.in_channels * self.se_ratio)) + self.block = nn.Sequential( + nn.Conv2d( + in_channels=self.channels, + out_channels=num_squeezed_channels, + kernel_size=1, + ), + nn.Mish(inplace=True), + nn.Conv2d( + in_channels=num_squeezed_channels, + out_channels=self.channels, + kernel_size=1, + ), + ) + + +@attr.s(auto_attribs=True, eq=False) +class Pointwise(BaseModule): + """Pointwise module.""" + + in_channels: int = attr.ib() + out_channels: int = attr.ib() + + def _build(self) -> None: + self.block = nn.Sequential( + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=1, + bias=False, + ), + nn.BatchNorm2d( + num_features=self.out_channels, + momentum=self.bn_momentum, + eps=self.bn_eps, + ), + ) + + +@attr.s(eq=False) +class MBConvBlock(nn.Module): + """Mobile Inverted Residual Bottleneck block.""" + + def __attrs_pre_init__(self) -> None: + super().__init__() + + in_channels: int = attr.ib() + out_channels: int = attr.ib() + kernel_size: Tuple[int, int] = attr.ib() + stride: Tuple[int, int] = attr.ib(converter=_convert_stride) + bn_momentum: float = attr.ib() + bn_eps: float = attr.ib() + se_ratio: float = attr.ib() + expand_ratio: int = attr.ib() + pad: Tuple[int, int, int, int] = attr.ib(init=False) + _inverted_bottleneck: Optional[InvertedBottleneck] = attr.ib(init=False) + _depthwise: nn.Sequential = attr.ib(init=False) + _squeeze_excite: nn.Sequential = attr.ib(init=False) + _pointwise: nn.Sequential = attr.ib(init=False) + + @pad.default + def _configure_padding(self) -> Tuple[int, int, int, int]: + """Set padding for convolutional layers.""" + if self.stride == (2, 2): + return ((self.kernel_size - 1) // 2 - 1, (self.kernel_size - 1) // 2,) * 2 + return ((self.kernel_size - 1) // 2,) * 4 + + def __attrs_post_init__(self) -> None: + """Post init configuration.""" + self._build() + + def _build(self) -> None: + has_se = self.se_ratio is not None and 0.0 < self.se_ratio < 1.0 + inner_channels = self.in_channels * self.expand_ratio + self._inverted_bottleneck = ( + InvertedBottleneck( + in_channels=self.in_channels, + out_channels=inner_channels, + bn_momentum=self.bn_momentum, + bn_eps=self.bn_eps, + ) + if self.expand_ratio != 1 + else None + ) + + self._depthwise = Depthwise( + channels=inner_channels, + kernel_size=self.kernel_size, + stride=self.stride, + bn_momentum=self.bn_momentum, + bn_eps=self.bn_eps, + ) + + self._squeeze_excite = ( + SqueezeAndExcite( + in_channels=self.in_channels, + channels=inner_channels, + se_ratio=self.se_ratio, + bn_momentum=self.bn_momentum, + bn_eps=self.bn_eps, + ) + if has_se + else None + ) + + self._pointwise = Pointwise( + in_channels=inner_channels, + out_channels=self.out_channels, + bn_momentum=self.bn_momentum, + bn_eps=self.bn_eps, + ) + + def _stochastic_depth( + self, x: Tensor, residual: Tensor, stochastic_dropout_rate: Optional[float] + ) -> Tensor: + if self.stride == (1, 1) and self.in_channels == self.out_channels: + if stochastic_dropout_rate: + x = stochastic_depth( + x, p=stochastic_dropout_rate, training=self.training + ) + x += residual + return x + + def forward( + self, x: Tensor, stochastic_dropout_rate: Optional[float] = None + ) -> Tensor: + """Forward pass.""" + residual = x + if self._inverted_bottleneck is not None: + x = self._inverted_bottleneck(x) + + x = F.pad(x, self.pad) + x = self._depthwise(x) + + if self._squeeze_excite is not None: + x_squeezed = F.adaptive_avg_pool2d(x, 1) + x_squeezed = self._squeeze_excite(x) + x = torch.tanh(F.softplus(x_squeezed)) * x + + x = self._pointwise(x) + + # Stochastic depth + x = self._stochastic_depth(x, residual, stochastic_dropout_rate) + return x diff --git a/text_recognizer/networks/efficientnet/utils.py b/text_recognizer/networks/efficientnet/utils.py new file mode 100644 index 0000000..5234324 --- /dev/null +++ b/text_recognizer/networks/efficientnet/utils.py @@ -0,0 +1,86 @@ +"""Util functions for efficient net.""" +import math +from typing import List, Tuple + +from omegaconf import DictConfig, OmegaConf +import torch +from torch import Tensor + + +def stochastic_depth(x: Tensor, p: float, training: bool) -> Tensor: + """Stochastic connection. + + Drops the entire convolution with a given survival probability. + + Args: + x (Tensor): Input tensor. + p (float): Survival probability between 0.0 and 1.0. + training (bool): The running mode. + + Shapes: + - x: :math: `(B, C, W, H)`. + - out: :math: `(B, C, W, H)`. + + where B is the batch size, C is the number of channels, W is the width, and H + is the height. + + Returns: + out (Tensor): Output after drop connection. + """ + assert 0.0 <= p <= 1.0, "p must be in range of [0, 1]" + + if not training: + return x + + bsz = x.shape[0] + survival_prob = 1 - p + + # Generate a binary tensor mask according to probability (p for 0, 1-p for 1) + random_tensor = survival_prob + random_tensor += torch.rand([bsz, 1, 1, 1]).type_as(x) + binary_tensor = torch.floor(random_tensor) + + out = x / survival_prob * binary_tensor + return out + + +def round_filters(filters: int, arch: Tuple[float, float, float]) -> int: + """Returns the number output filters for a block.""" + multiplier = arch[0] + divisor = 8 + filters *= multiplier + new_filters = max(divisor, (filters + divisor // 2) // divisor * divisor) + if new_filters < 0.9 * filters: + new_filters += divisor + return int(new_filters) + + +def round_repeats(repeats: int, arch: Tuple[float, float, float]) -> int: + """Returns how many times a layer should be repeated in a block.""" + return int(math.ceil(arch[1] * repeats)) + + +def block_args() -> List[DictConfig]: + """Returns arguments for each efficientnet block.""" + keys = [ + "num_repeats", + "kernel_size", + "stride", + "expand_ratio", + "in_channels", + "out_channels", + "se_ratio", + ] + args = [ + [1, 3, (1, 1), 1, 32, 16, 0.25], + [2, 3, (2, 2), 6, 16, 24, 0.25], + [2, 5, (2, 2), 6, 24, 40, 0.25], + [3, 3, (2, 2), 6, 40, 80, 0.25], + [3, 5, (1, 1), 6, 80, 112, 0.25], + [4, 5, (2, 2), 6, 112, 192, 0.25], + [1, 3, (1, 1), 6, 192, 320, 0.25], + ] + block_args_ = [] + for row in args: + block_args_.append(OmegaConf.create(dict(zip(keys, row)))) + return block_args_ |