diff options
Diffstat (limited to 'text_recognizer/networks/encoders')
-rw-r--r-- | text_recognizer/networks/encoders/efficientnet/efficientnet.py | 1 | ||||
-rw-r--r-- | text_recognizer/networks/encoders/efficientnet/utils.py | 10 |
2 files changed, 7 insertions, 4 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py index fb4f002..59598b5 100644 --- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py +++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py @@ -10,6 +10,7 @@ from .utils import ( class EfficientNet(nn.Module): + # TODO: attr archs = { # width,depth0res,dropout "b0": (1.0, 1.0, 0.2), diff --git a/text_recognizer/networks/encoders/efficientnet/utils.py b/text_recognizer/networks/encoders/efficientnet/utils.py index 6f293db..5234324 100644 --- a/text_recognizer/networks/encoders/efficientnet/utils.py +++ b/text_recognizer/networks/encoders/efficientnet/utils.py @@ -1,9 +1,8 @@ """Util functions for efficient net.""" -from functools import partial import math -from typing import Any, Optional, Union, Tuple, Type +from typing import List, Tuple -from omegaconf import OmegaConf +from omegaconf import DictConfig, OmegaConf import torch from torch import Tensor @@ -46,6 +45,7 @@ def stochastic_depth(x: Tensor, p: float, training: bool) -> Tensor: def round_filters(filters: int, arch: Tuple[float, float, float]) -> int: + """Returns the number output filters for a block.""" multiplier = arch[0] divisor = 8 filters *= multiplier @@ -56,10 +56,12 @@ def round_filters(filters: int, arch: Tuple[float, float, float]) -> int: def round_repeats(repeats: int, arch: Tuple[float, float, float]) -> int: + """Returns how many times a layer should be repeated in a block.""" return int(math.ceil(arch[1] * repeats)) -def block_args(): +def block_args() -> List[DictConfig]: + """Returns arguments for each efficientnet block.""" keys = [ "num_repeats", "kernel_size", |