From 53bfdaa0b4d3a04c5f2a274c5657ada81e9bf135 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Thu, 8 Jul 2021 22:25:04 +0200 Subject: Add comments --- text_recognizer/networks/encoders/efficientnet/efficientnet.py | 1 + text_recognizer/networks/encoders/efficientnet/utils.py | 10 ++++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py index fb4f002..59598b5 100644 --- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py +++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py @@ -10,6 +10,7 @@ from .utils import ( class EfficientNet(nn.Module): + # TODO: attr archs = { # width,depth0res,dropout "b0": (1.0, 1.0, 0.2), diff --git a/text_recognizer/networks/encoders/efficientnet/utils.py b/text_recognizer/networks/encoders/efficientnet/utils.py index 6f293db..5234324 100644 --- a/text_recognizer/networks/encoders/efficientnet/utils.py +++ b/text_recognizer/networks/encoders/efficientnet/utils.py @@ -1,9 +1,8 @@ """Util functions for efficient net.""" -from functools import partial import math -from typing import Any, Optional, Union, Tuple, Type +from typing import List, Tuple -from omegaconf import OmegaConf +from omegaconf import DictConfig, OmegaConf import torch from torch import Tensor @@ -46,6 +45,7 @@ def stochastic_depth(x: Tensor, p: float, training: bool) -> Tensor: def round_filters(filters: int, arch: Tuple[float, float, float]) -> int: + """Returns the number output filters for a block.""" multiplier = arch[0] divisor = 8 filters *= multiplier @@ -56,10 +56,12 @@ def round_filters(filters: int, arch: Tuple[float, float, float]) -> int: def round_repeats(repeats: int, arch: Tuple[float, float, float]) -> int: + """Returns how many times a layer should be repeated in a block.""" return int(math.ceil(arch[1] * repeats)) -def block_args(): +def block_args() -> List[DictConfig]: + """Returns arguments for each efficientnet block.""" keys = [ "num_repeats", "kernel_size", -- cgit v1.2.3-70-g09d2