diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-06-24 00:02:47 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-06-24 00:02:47 +0200 |
commit | 1d7f674236d0622addc243d15c05a1dd30ca8121 (patch) | |
tree | 53d57dcfb4f2bcc8fef010012db08b7bde2a1559 /text_recognizer/networks/encoders/efficientnet/utils.py | |
parent | 038195369b3909feeeceb006d52f3af11e3081df (diff) |
Still working on efficientnet
Diffstat (limited to 'text_recognizer/networks/encoders/efficientnet/utils.py')
-rw-r--r-- | text_recognizer/networks/encoders/efficientnet/utils.py | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/utils.py b/text_recognizer/networks/encoders/efficientnet/utils.py index 4b4a787..ff52485 100644 --- a/text_recognizer/networks/encoders/efficientnet/utils.py +++ b/text_recognizer/networks/encoders/efficientnet/utils.py @@ -3,6 +3,7 @@ from functools import partial import math from typing import Any, Optional, Tuple, Type +from omegaconf import OmegaConf import torch from torch import nn, Tensor import torch.functional as F @@ -139,3 +140,42 @@ class Conv2dStaticSamePadding(nn.Conv2d): self.groups, ) return x + + +def round_filters(filters: int, arch: Tuple[float, float, float]) -> int: + multiplier = arch[0] + divisor = 8 + filters *= multiplier + new_filters = max(divisor, (filters + divisor // 2) // divisor * divisor) + if new_filters < 0.9 * filters: + new_filters += divisor + return int(new_filters) + + +def round_repeats(repeats: int, arch: Tuple[float, float, float]) -> int: + return int(math.ceil(arch[1] * repeats)) + + +def block_args(): + keys = [ + "num_repeats", + "kernel_size", + "strides", + "expand_ratio", + "in_channels", + "out_channels", + "se_ratio", + ] + args = [ + [1, 3, (1, 1), 1, 32, 16, 0.25], + [2, 3, (2, 2), 6, 16, 24, 0.25], + [2, 5, (2, 2), 6, 24, 40, 0.25], + [3, 3, (2, 2), 6, 40, 80, 0.25], + [3, 5, (1, 1), 6, 80, 112, 0.25], + [4, 5, (2, 2), 6, 112, 192, 0.25], + [1, 3, (1, 1), 6, 192, 320, 0.25], + ] + block_args_ = [] + for row in args: + block_args_.append(OmegaConf.create(dict(zip(keys, row)))) + return block_args_ |