summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/encoders/efficientnet
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-06-24 00:02:47 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-06-24 00:02:47 +0200
commit1d7f674236d0622addc243d15c05a1dd30ca8121 (patch)
tree53d57dcfb4f2bcc8fef010012db08b7bde2a1559 /text_recognizer/networks/encoders/efficientnet
parent038195369b3909feeeceb006d52f3af11e3081df (diff)
Still working on efficientnet
Diffstat (limited to 'text_recognizer/networks/encoders/efficientnet')
-rw-r--r--text_recognizer/networks/encoders/efficientnet/efficientnet.py95
-rw-r--r--text_recognizer/networks/encoders/efficientnet/mbconv.py6
-rw-r--r--text_recognizer/networks/encoders/efficientnet/utils.py40
3 files changed, 135 insertions, 6 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
index d953c10..98d58fd 100644
--- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
@@ -1,9 +1,98 @@
"""Efficient net."""
+from typing import Tuple
+
from torch import nn, Tensor
+from .mbconv import MBConvBlock
+from .utils import (
+ block_args,
+ calculate_output_image_size,
+ get_same_padding_conv2d,
+ round_filters,
+ round_repeats,
+)
+
class EfficientNet(nn.Module):
- def __init__(
- self,
- ) -> None:
+ archs = {
+ # width,depth0res,dropout
+ "b0": (1.0, 1.0, 0.2),
+ "b1": (1.0, 1.1, 0.2),
+ "b2": (1.1, 1.2, 0.3),
+ "b3": (1.2, 1.4, 0.3),
+ "b4": (1.4, 1.8, 0.4),
+ "b5": (1.6, 2.2, 0.4),
+ "b6": (1.8, 2.6, 0.5),
+ "b7": (2.0, 3.1, 0.5),
+ "b8": (2.2, 3.6, 0.5),
+ "l2": (4.3, 5.3, 0.5),
+ }
+
+ def __init__(self, arch: str, image_size: Tuple[int, int]) -> None:
super().__init__()
+ assert arch in self.archs, f"{arch} not a valid efficient net architecure!"
+ self.arch = self.archs[arch]
+ self.image_size = image_size
+ self._conv_stem: nn.Sequential = None
+ self._blocks: nn.Sequential = None
+ self._conv_head: nn.Sequential = None
+ self._build()
+
+ def _build(self) -> None:
+ _block_args = block_args()
+ in_channels = 1 # BW
+ out_channels = round_filters(32, self.arch)
+ Conv2d = get_same_padding_conv2d(image_size=self.image_size)
+ self._conv_stem = nn.Sequential(
+ Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=2,
+ bias=False,
+ ),
+ nn.BatchNorm2d(num_features=out_channels, momentum=bn_momentum, eps=bn_eps),
+ nn.Mish(inplace=True),
+ )
+ image_size = calculate_output_image_size(self.image_size, 2)
+ self._blocks = nn.ModuleList([])
+ for args in _block_args:
+ args.in_channels = round_filters(args.in_channels, self.arch)
+ args.out_channels = round_filters(args.out_channels, self.arch)
+ args.num_repeat = round_repeats(args.num_repeat, self.arch)
+
+ self._blocks.append(
+ MBConvBlock(
+ **args,
+ bn_momentum=bn_momentum,
+ bn_eps=bn_eps,
+ image_size=image_size,
+ )
+ )
+ image_size = calculate_output_image_size(image_size, args.stride)
+ if args.num_repeat > 1:
+ args.in_channels = args.out_channels
+ args.stride = 1
+ for _ in range(args.num_repeat - 1):
+ self._blocks.append(
+ MBConvBlock(
+ **args,
+ bn_momentum=bn_momentum,
+ bn_eps=bn_eps,
+ image_size=image_size,
+ )
+ )
+
+ in_channels = args.out_channels
+ out_channels = round_filters(1280, self.arch)
+ Conv2d = get_same_padding_conv2d(image_size=image_size)
+ self._conv_head = nn.Sequential(
+ Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
+ nn.BatchNorm2d(num_features=out_channels, momentum=bn_momentum, eps=bn_eps),
+ )
+
+ def extract_features(self, x: Tensor) -> Tensor:
+ x = self._conv_stem(x)
+
+ def forward(self, x: Tensor) -> Tensor:
+ pass
diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/encoders/efficientnet/mbconv.py
index 602aeb7..fbb3f22 100644
--- a/text_recognizer/networks/encoders/efficientnet/mbconv.py
+++ b/text_recognizer/networks/encoders/efficientnet/mbconv.py
@@ -1,5 +1,5 @@
"""Mobile inverted residual block."""
-from typing import Optional, Tuple
+from typing import Any, Optional, Tuple
import torch
from torch import nn, Tensor
@@ -20,15 +20,15 @@ class MBConvBlock(nn.Module):
bn_momentum: float,
bn_eps: float,
se_ratio: float,
- id_skip: bool,
expand_ratio: int,
image_size: Optional[Tuple[int, int]],
+ *args: Any,
+ **kwargs: Any,
) -> None:
super().__init__()
self.kernel_size = kernel_size
self.bn_momentum = bn_momentum
self.bn_eps = bn_eps
- self.id_skip = id_skip
self.in_channels = self.in_channels
self.out_channels = out_channels
diff --git a/text_recognizer/networks/encoders/efficientnet/utils.py b/text_recognizer/networks/encoders/efficientnet/utils.py
index 4b4a787..ff52485 100644
--- a/text_recognizer/networks/encoders/efficientnet/utils.py
+++ b/text_recognizer/networks/encoders/efficientnet/utils.py
@@ -3,6 +3,7 @@ from functools import partial
import math
from typing import Any, Optional, Tuple, Type
+from omegaconf import OmegaConf
import torch
from torch import nn, Tensor
import torch.functional as F
@@ -139,3 +140,42 @@ class Conv2dStaticSamePadding(nn.Conv2d):
self.groups,
)
return x
+
+
+def round_filters(filters: int, arch: Tuple[float, float, float]) -> int:
+ multiplier = arch[0]
+ divisor = 8
+ filters *= multiplier
+ new_filters = max(divisor, (filters + divisor // 2) // divisor * divisor)
+ if new_filters < 0.9 * filters:
+ new_filters += divisor
+ return int(new_filters)
+
+
+def round_repeats(repeats: int, arch: Tuple[float, float, float]) -> int:
+ return int(math.ceil(arch[1] * repeats))
+
+
+def block_args():
+ keys = [
+ "num_repeats",
+ "kernel_size",
+ "strides",
+ "expand_ratio",
+ "in_channels",
+ "out_channels",
+ "se_ratio",
+ ]
+ args = [
+ [1, 3, (1, 1), 1, 32, 16, 0.25],
+ [2, 3, (2, 2), 6, 16, 24, 0.25],
+ [2, 5, (2, 2), 6, 24, 40, 0.25],
+ [3, 3, (2, 2), 6, 40, 80, 0.25],
+ [3, 5, (1, 1), 6, 80, 112, 0.25],
+ [4, 5, (2, 2), 6, 112, 192, 0.25],
+ [1, 3, (1, 1), 6, 192, 320, 0.25],
+ ]
+ block_args_ = []
+ for row in args:
+ block_args_.append(OmegaConf.create(dict(zip(keys, row))))
+ return block_args_