summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/encoders/efficientnet/utils.py
blob: 6f293dbcf793cbe1e81ad2d10d569ca3e3e93812 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""Util functions for efficient net."""
from functools import partial
import math
from typing import Any, Optional, Union, Tuple, Type

from omegaconf import OmegaConf
import torch
from torch import Tensor


def stochastic_depth(x: Tensor, p: float, training: bool) -> Tensor:
    """Stochastic depth: randomly drop whole samples of a residual branch.

    During training, every sample in the batch is zeroed out entirely with
    probability ``p``. Surviving samples are scaled by ``1 / (1 - p)`` so the
    expected value of the output equals the input. Outside training the input
    is returned unchanged.

    Args:
        x (Tensor): Input tensor whose first dimension is the batch.
        p (float): Drop probability between 0.0 and 1.0. The survival
            probability is ``1 - p``.
        training (bool): The running mode; samples are only dropped when True.

    Shapes:
        - x: :math: `(B, *)`, e.g. `(B, C, W, H)`.
        - out: same shape as x.

        where B is the batch size and * is any number of further dimensions.

    Returns:
        out (Tensor): Output after drop connection.

    Raises:
        ValueError: If p is not in the range [0, 1].
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError("p must be in range of [0, 1]")

    if not training:
        return x

    survival_prob = 1.0 - p
    if survival_prob == 0.0:
        # Every sample is dropped. Dividing by a zero survival probability
        # would otherwise turn the whole output into NaN.
        return torch.zeros_like(x)

    # One uniform draw per sample, broadcast over all remaining dimensions.
    mask_shape = (x.shape[0],) + (1,) * (x.dim() - 1)
    noise = torch.rand(mask_shape, device=x.device)

    # Keep a sample when u >= p, which happens with probability 1 - p.
    # A comparison gives an exact {0, 1} mask. floor(survival_prob + u) can
    # round up to 2 in reduced precision.
    keep_mask = (noise >= p).to(dtype=x.dtype)

    return x / survival_prob * keep_mask


def round_filters(
    filters: int, arch: Tuple[float, float, float], *, divisor: int = 8
) -> int:
    """Scale a channel count by ``arch[0]`` and round it to a multiple of ``divisor``.

    The scaled count ``filters * arch[0]`` is rounded to the nearest multiple
    of ``divisor``, and is never less than ``divisor``. If that rounding would
    drop more than 10% of the scaled count, one extra ``divisor`` is added.

    Args:
        filters (int): Base number of channels.
        arch (Tuple[float, float, float]): Architecture coefficients. Only
            ``arch[0]``, the channel multiplier, is used here.
        divisor (int): The result is a multiple of this value. Defaults to 8.

    Returns:
        int: The scaled and rounded channel count.

    Raises:
        ValueError: If divisor is not positive.
    """
    if divisor <= 0:
        raise ValueError("divisor must be a positive integer")

    scaled = filters * arch[0]
    new_filters = max(divisor, (scaled + divisor // 2) // divisor * divisor)
    # Snapping to a multiple of divisor must not shrink the layer by > 10%.
    if new_filters < 0.9 * scaled:
        new_filters += divisor
    return int(new_filters)


def round_repeats(repeats: int, arch: Tuple[float, float, float]) -> int:
    """Scale a block repeat count by ``arch[1]`` and round up.

    Args:
        repeats (int): Base number of times a block is repeated.
        arch (Tuple[float, float, float]): Architecture coefficients. Only
            ``arch[1]``, used as the repeat multiplier, is read here.

    Returns:
        int: ``ceil(arch[1] * repeats)``.
    """
    depth_multiplier = arch[1]
    scaled_repeats = depth_multiplier * repeats
    return int(math.ceil(scaled_repeats))


def block_args():
    """Build the per-stage block configurations.

    Each row of the table below is zipped with the field names and wrapped in
    an OmegaConf config.

    Returns:
        list: One OmegaConf config per stage (seven stages), in table order.
    """
    fields = (
        "num_repeats",
        "kernel_size",
        "stride",
        "expand_ratio",
        "in_channels",
        "out_channels",
        "se_ratio",
    )
    # NOTE(review): these rows appear to match the EfficientNet-B0 baseline
    # stages. Confirm before changing.
    stages = (
        (1, 3, (1, 1), 1, 32, 16, 0.25),
        (2, 3, (2, 2), 6, 16, 24, 0.25),
        (2, 5, (2, 2), 6, 24, 40, 0.25),
        (3, 3, (2, 2), 6, 40, 80, 0.25),
        (3, 5, (1, 1), 6, 80, 112, 0.25),
        (4, 5, (2, 2), 6, 112, 192, 0.25),
        (1, 3, (1, 1), 6, 192, 320, 0.25),
    )
    return [OmegaConf.create(dict(zip(fields, stage))) for stage in stages]