summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/encoders/efficientnet/utils.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-06-24 00:02:47 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-06-24 00:02:47 +0200
commit1d7f674236d0622addc243d15c05a1dd30ca8121 (patch)
tree53d57dcfb4f2bcc8fef010012db08b7bde2a1559 /text_recognizer/networks/encoders/efficientnet/utils.py
parent038195369b3909feeeceb006d52f3af11e3081df (diff)
Still working on efficientnet
Diffstat (limited to 'text_recognizer/networks/encoders/efficientnet/utils.py')
-rw-r--r--text_recognizer/networks/encoders/efficientnet/utils.py40
1 files changed, 40 insertions, 0 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/utils.py b/text_recognizer/networks/encoders/efficientnet/utils.py
index 4b4a787..ff52485 100644
--- a/text_recognizer/networks/encoders/efficientnet/utils.py
+++ b/text_recognizer/networks/encoders/efficientnet/utils.py
@@ -3,6 +3,7 @@ from functools import partial
import math
from typing import Any, Optional, Tuple, Type
+from omegaconf import OmegaConf
import torch
from torch import nn, Tensor
import torch.functional as F
@@ -139,3 +140,42 @@ class Conv2dStaticSamePadding(nn.Conv2d):
self.groups,
)
return x
+
+
+def round_filters(filters: int, arch: Tuple[float, float, float]) -> int:
+ multiplier = arch[0]
+ divisor = 8
+ filters *= multiplier
+ new_filters = max(divisor, (filters + divisor // 2) // divisor * divisor)
+ if new_filters < 0.9 * filters:
+ new_filters += divisor
+ return int(new_filters)
+
+
+def round_repeats(repeats: int, arch: Tuple[float, float, float]) -> int:
+ return int(math.ceil(arch[1] * repeats))
+
+
+def block_args():
+ keys = [
+ "num_repeats",
+ "kernel_size",
+ "strides",
+ "expand_ratio",
+ "in_channels",
+ "out_channels",
+ "se_ratio",
+ ]
+ args = [
+ [1, 3, (1, 1), 1, 32, 16, 0.25],
+ [2, 3, (2, 2), 6, 16, 24, 0.25],
+ [2, 5, (2, 2), 6, 24, 40, 0.25],
+ [3, 3, (2, 2), 6, 40, 80, 0.25],
+ [3, 5, (1, 1), 6, 80, 112, 0.25],
+ [4, 5, (2, 2), 6, 112, 192, 0.25],
+ [1, 3, (1, 1), 6, 192, 320, 0.25],
+ ]
+ block_args_ = []
+ for row in args:
+ block_args_.append(OmegaConf.create(dict(zip(keys, row))))
+ return block_args_