summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/encoders/efficientnet/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/encoders/efficientnet/utils.py')
-rw-r--r--text_recognizer/networks/encoders/efficientnet/utils.py10
1 files changed, 6 insertions, 4 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/utils.py b/text_recognizer/networks/encoders/efficientnet/utils.py
index 6f293db..5234324 100644
--- a/text_recognizer/networks/encoders/efficientnet/utils.py
+++ b/text_recognizer/networks/encoders/efficientnet/utils.py
@@ -1,9 +1,8 @@
"""Util functions for efficient net."""
-from functools import partial
import math
-from typing import Any, Optional, Union, Tuple, Type
+from typing import List, Tuple
-from omegaconf import OmegaConf
+from omegaconf import DictConfig, OmegaConf
import torch
from torch import Tensor
@@ -46,6 +45,7 @@ def stochastic_depth(x: Tensor, p: float, training: bool) -> Tensor:
def round_filters(filters: int, arch: Tuple[float, float, float]) -> int:
+ """Returns the number output filters for a block."""
multiplier = arch[0]
divisor = 8
filters *= multiplier
@@ -56,10 +56,12 @@ def round_filters(filters: int, arch: Tuple[float, float, float]) -> int:
def round_repeats(repeats: int, arch: Tuple[float, float, float]) -> int:
+ """Returns how many times a layer should be repeated in a block."""
return int(math.ceil(arch[1] * repeats))
-def block_args():
+def block_args() -> List[DictConfig]:
+ """Returns arguments for each efficientnet block."""
keys = [
"num_repeats",
"kernel_size",