diff options
Diffstat (limited to 'text_recognizer/networks/encoders/efficientnet/utils.py')
-rw-r--r-- | text_recognizer/networks/encoders/efficientnet/utils.py | 40 |
1 files changed, 40 insertions, 0 deletions
# Reconstructed from the diff of
# text_recognizer/networks/encoders/efficientnet/utils.py (4b4a787..ff52485).
# Hunk "@@ -3,6 +3,7 @@" adds `from omegaconf import OmegaConf` to the module
# imports; hunk "@@ -139,3 +140,42 @@" appends the three helpers below after
# the end of Conv2dStaticSamePadding (whose body is not visible in the diff).
import math
from typing import Any, List, Tuple


def round_filters(filters: int, arch: Tuple[float, float, float]) -> int:
    """Scale a filter count by ``arch[0]`` and snap it to a multiple of 8.

    Args:
        filters: Base number of filters (channels).
        arch: Architecture coefficients; only ``arch[0]`` is read here and is
            used as the multiplier (presumably the EfficientNet width
            coefficient -- confirm against the caller).

    Returns:
        The scaled filter count rounded to the nearest multiple of 8, never
        less than 8. If rounding shrank the scaled value by more than 10 %,
        one extra step of 8 is added.
    """
    divisor = 8
    scaled = filters * arch[0]
    # Adding half a divisor before floor-dividing rounds to the nearest
    # multiple; max() keeps at least one full divisor worth of filters.
    new_filters = max(divisor, (scaled + divisor // 2) // divisor * divisor)
    # Guard against rounding down by more than 10 % of the scaled value.
    if new_filters < 0.9 * scaled:
        new_filters += divisor
    return int(new_filters)


def round_repeats(repeats: int, arch: Tuple[float, float, float]) -> int:
    """Scale a repeat count by ``arch[1]`` and round up to an integer.

    Args:
        repeats: Base number of block repeats.
        arch: Architecture coefficients; only ``arch[1]`` is read here
            (presumably the EfficientNet depth coefficient -- confirm against
            the caller).

    Returns:
        ``ceil(arch[1] * repeats)`` as an ``int``.
    """
    return int(math.ceil(arch[1] * repeats))


def block_args() -> List[Any]:
    """Build the per-stage block configuration table.

    Returns:
        A list with one OmegaConf config per stage, in table order. Each
        config has the keys ``num_repeats``, ``kernel_size``, ``strides``,
        ``expand_ratio``, ``in_channels``, ``out_channels`` and ``se_ratio``.
    """
    # Imported here so the arithmetic helpers above remain importable even
    # when omegaconf is not installed.
    from omegaconf import OmegaConf

    keys = [
        "num_repeats",
        "kernel_size",
        "strides",
        "expand_ratio",
        "in_channels",
        "out_channels",
        "se_ratio",
    ]
    # One row per stage; column order matches `keys`.
    args = [
        [1, 3, (1, 1), 1, 32, 16, 0.25],
        [2, 3, (2, 2), 6, 16, 24, 0.25],
        [2, 5, (2, 2), 6, 24, 40, 0.25],
        [3, 3, (2, 2), 6, 40, 80, 0.25],
        [3, 5, (1, 1), 6, 80, 112, 0.25],
        [4, 5, (2, 2), 6, 112, 192, 0.25],
        [1, 3, (1, 1), 6, 192, 320, 0.25],
    ]
    return [OmegaConf.create(dict(zip(keys, row))) for row in args]