diff options
Diffstat (limited to 'text_recognizer')
3 files changed, 135 insertions, 6 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py index d953c10..98d58fd 100644 --- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py +++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py @@ -1,9 +1,98 @@ """Efficient net.""" +from typing import Tuple + from torch import nn, Tensor +from .mbconv import MBConvBlock +from .utils import ( + block_args, + calculate_output_image_size, + get_same_padding_conv2d, + round_filters, + round_repeats, +) + class EfficientNet(nn.Module): - def __init__( - self, - ) -> None: + archs = { + # width,depth0res,dropout + "b0": (1.0, 1.0, 0.2), + "b1": (1.0, 1.1, 0.2), + "b2": (1.1, 1.2, 0.3), + "b3": (1.2, 1.4, 0.3), + "b4": (1.4, 1.8, 0.4), + "b5": (1.6, 2.2, 0.4), + "b6": (1.8, 2.6, 0.5), + "b7": (2.0, 3.1, 0.5), + "b8": (2.2, 3.6, 0.5), + "l2": (4.3, 5.3, 0.5), + } + + def __init__(self, arch: str, image_size: Tuple[int, int]) -> None: super().__init__() + assert arch in self.archs, f"{arch} not a valid efficient net architecure!" + self.arch = self.archs[arch] + self.image_size = image_size + self._conv_stem: nn.Sequential = None + self._blocks: nn.Sequential = None + self._conv_head: nn.Sequential = None + self._build() + + def _build(self) -> None: + _block_args = block_args() + in_channels = 1 # BW + out_channels = round_filters(32, self.arch) + Conv2d = get_same_padding_conv2d(image_size=self.image_size) + self._conv_stem = nn.Sequential( + Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + bias=False, + ), + nn.BatchNorm2d(num_features=out_channels, momentum=bn_momentum, eps=bn_eps), + nn.Mish(inplace=True), + ) + image_size = calculate_output_image_size(self.image_size, 2) + self._blocks = nn.ModuleList([]) + for args in _block_args: + args.in_channels = round_filters(args.in_channels, self.arch) + args.out_channels = round_filters(args.out_channels, self.arch) + args.num_repeat = round_repeats(args.num_repeat, self.arch) + + self._blocks.append( + MBConvBlock( + **args, + bn_momentum=bn_momentum, + bn_eps=bn_eps, + image_size=image_size, + ) + ) + image_size = calculate_output_image_size(image_size, args.stride) + if args.num_repeat > 1: + args.in_channels = args.out_channels + args.stride = 1 + for _ in range(args.num_repeat - 1): + self._blocks.append( + MBConvBlock( + **args, + bn_momentum=bn_momentum, + bn_eps=bn_eps, + image_size=image_size, + ) + ) + + in_channels = args.out_channels + out_channels = round_filters(1280, self.arch) + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._conv_head = nn.Sequential( + Conv2d(in_channels, out_channels, kernel_size=1, bias=False), + nn.BatchNorm2d(num_features=out_channels, momentum=bn_momentum, eps=bn_eps), + ) + + def extract_features(self, x: Tensor) -> Tensor: + x = self._conv_stem(x) + + def forward(self, x: Tensor) -> Tensor: + pass diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/encoders/efficientnet/mbconv.py index 602aeb7..fbb3f22 100644 --- a/text_recognizer/networks/encoders/efficientnet/mbconv.py +++ b/text_recognizer/networks/encoders/efficientnet/mbconv.py @@ -1,5 +1,5 @@ """Mobile inverted residual block.""" -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import torch from torch import nn, Tensor @@ -20,15 +20,15 @@ class MBConvBlock(nn.Module): bn_momentum: float, bn_eps: float, se_ratio: float, - id_skip: bool, expand_ratio: int, image_size: Optional[Tuple[int, int]], + *args: Any, + **kwargs: Any, ) -> None: super().__init__() self.kernel_size = kernel_size self.bn_momentum = bn_momentum self.bn_eps = bn_eps - self.id_skip = id_skip self.in_channels = self.in_channels self.out_channels = out_channels diff --git a/text_recognizer/networks/encoders/efficientnet/utils.py b/text_recognizer/networks/encoders/efficientnet/utils.py index 4b4a787..ff52485 100644 --- a/text_recognizer/networks/encoders/efficientnet/utils.py +++ b/text_recognizer/networks/encoders/efficientnet/utils.py @@ -3,6 +3,7 @@ from functools import partial import math from typing import Any, Optional, Tuple, Type +from omegaconf import OmegaConf import torch from torch import nn, Tensor import torch.functional as F @@ -139,3 +140,42 @@ class Conv2dStaticSamePadding(nn.Conv2d): self.groups, ) return x + + +def round_filters(filters: int, arch: Tuple[float, float, float]) -> int: + multiplier = arch[0] + divisor = 8 + filters *= multiplier + new_filters = max(divisor, (filters + divisor // 2) // divisor * divisor) + if new_filters < 0.9 * filters: + new_filters += divisor + return int(new_filters) + + +def round_repeats(repeats: int, arch: Tuple[float, float, float]) -> int: + return int(math.ceil(arch[1] * repeats)) + + +def block_args(): + keys = [ + "num_repeats", + "kernel_size", + "strides", + "expand_ratio", + "in_channels", + "out_channels", + "se_ratio", + ] + args = [ + [1, 3, (1, 1), 1, 32, 16, 0.25], + [2, 3, (2, 2), 6, 16, 24, 0.25], + [2, 5, (2, 2), 6, 24, 40, 0.25], + [3, 3, (2, 2), 6, 40, 80, 0.25], + [3, 5, (1, 1), 6, 80, 112, 0.25], + [4, 5, (2, 2), 6, 112, 192, 0.25], + [1, 3, (1, 1), 6, 192, 320, 0.25], + ] + block_args_ = [] + for row in args: + block_args_.append(OmegaConf.create(dict(zip(keys, row)))) + return block_args_ |