diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-08 22:25:04 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-08 22:25:04 +0200 |
commit | 53bfdaa0b4d3a04c5f2a274c5657ada81e9bf135 (patch) | |
tree | 2509e7bb5a25311788d654e3a2e665578f9c9403 /text_recognizer/networks/encoders/efficientnet | |
parent | d717f9e3e4dd17351f08c5822cb90d055c4513cc (diff) |
Add comments
Diffstat (limited to 'text_recognizer/networks/encoders/efficientnet')
-rw-r--r-- | text_recognizer/networks/encoders/efficientnet/efficientnet.py | 1 | ||||
-rw-r--r-- | text_recognizer/networks/encoders/efficientnet/utils.py | 10 |
2 files changed, 7 insertions, 4 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py index fb4f002..59598b5 100644 --- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py +++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py @@ -10,6 +10,7 @@ from .utils import ( class EfficientNet(nn.Module): + # TODO: attr archs = { # width,depth0res,dropout "b0": (1.0, 1.0, 0.2), diff --git a/text_recognizer/networks/encoders/efficientnet/utils.py b/text_recognizer/networks/encoders/efficientnet/utils.py index 6f293db..5234324 100644 --- a/text_recognizer/networks/encoders/efficientnet/utils.py +++ b/text_recognizer/networks/encoders/efficientnet/utils.py @@ -1,9 +1,8 @@ """Util functions for efficient net.""" -from functools import partial import math -from typing import Any, Optional, Union, Tuple, Type +from typing import List, Tuple -from omegaconf import OmegaConf +from omegaconf import DictConfig, OmegaConf import torch from torch import Tensor @@ -46,6 +45,7 @@ def stochastic_depth(x: Tensor, p: float, training: bool) -> Tensor: def round_filters(filters: int, arch: Tuple[float, float, float]) -> int: + """Returns the number output filters for a block.""" multiplier = arch[0] divisor = 8 filters *= multiplier @@ -56,10 +56,12 @@ def round_filters(filters: int, arch: Tuple[float, float, float]) -> int: def round_repeats(repeats: int, arch: Tuple[float, float, float]) -> int: + """Returns how many times a layer should be repeated in a block.""" return int(math.ceil(arch[1] * repeats)) -def block_args(): +def block_args() -> List[DictConfig]: + """Returns arguments for each efficientnet block.""" keys = [ "num_repeats", "kernel_size", |