summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/encoders
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/encoders')
-rw-r--r--text_recognizer/networks/encoders/efficientnet/efficientnet.py1
-rw-r--r--text_recognizer/networks/encoders/efficientnet/utils.py10
2 files changed, 7 insertions, 4 deletions
diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
index fb4f002..59598b5 100644
--- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
@@ -10,6 +10,7 @@ from .utils import (
class EfficientNet(nn.Module):
+ # TODO: attr
archs = {
# width,depth0res,dropout
"b0": (1.0, 1.0, 0.2),
diff --git a/text_recognizer/networks/encoders/efficientnet/utils.py b/text_recognizer/networks/encoders/efficientnet/utils.py
index 6f293db..5234324 100644
--- a/text_recognizer/networks/encoders/efficientnet/utils.py
+++ b/text_recognizer/networks/encoders/efficientnet/utils.py
@@ -1,9 +1,8 @@
"""Util functions for efficient net."""
-from functools import partial
import math
-from typing import Any, Optional, Union, Tuple, Type
+from typing import List, Tuple
-from omegaconf import OmegaConf
+from omegaconf import DictConfig, OmegaConf
import torch
from torch import Tensor
@@ -46,6 +45,7 @@ def stochastic_depth(x: Tensor, p: float, training: bool) -> Tensor:
def round_filters(filters: int, arch: Tuple[float, float, float]) -> int:
+ """Returns the number output filters for a block."""
multiplier = arch[0]
divisor = 8
filters *= multiplier
@@ -56,10 +56,12 @@ def round_filters(filters: int, arch: Tuple[float, float, float]) -> int:
def round_repeats(repeats: int, arch: Tuple[float, float, float]) -> int:
+ """Returns how many times a layer should be repeated in a block."""
return int(math.ceil(arch[1] * repeats))
-def block_args():
+def block_args() -> List[DictConfig]:
+ """Returns arguments for each efficientnet block."""
keys = [
"num_repeats",
"kernel_size",