summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/util.py
blob: e822c5706e03dd6115c185bbb17e402f2365376d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
"""Miscellaneous neural network utility functionality."""
from functools import partial
from importlib import import_module
from typing import Any, Type

from torch import nn


def activation_function(activation: str) -> nn.Module:
    """Returns a freshly constructed activation module for the given name.

    Args:
        activation: Case-insensitive activation name. One of ``"elu"``,
            ``"gelu"``, ``"glu"``, ``"leaky_relu"``, ``"none"`` (identity),
            ``"relu"``, ``"selu"`` or ``"mish"``.

    Returns:
        A new ``nn.Module`` instance of the requested activation. ELU,
        LeakyReLU, ReLU, SELU and Mish are created with ``inplace=True``;
        LeakyReLU uses ``negative_slope=1e-2``.

    Raises:
        KeyError: If ``activation`` is not a known activation name.
    """
    # Map names to factories rather than instances so only the requested
    # module is constructed on each call (the rest are never built).
    factories = {
        "elu": partial(nn.ELU, inplace=True),
        "gelu": nn.GELU,
        "glu": nn.GLU,
        "leaky_relu": partial(nn.LeakyReLU, negative_slope=1.0e-2, inplace=True),
        "none": nn.Identity,
        "relu": partial(nn.ReLU, inplace=True),
        "selu": partial(nn.SELU, inplace=True),
        "mish": partial(nn.Mish, inplace=True),
    }
    try:
        factory = factories[activation.lower()]
    except KeyError:
        raise KeyError(
            f"Unknown activation {activation!r}; expected one of {sorted(factories)}"
        ) from None
    return factory()


def load_partial_fn(fn: str, **kwargs: Any) -> partial:
    """Loads a function by its fully qualified name and binds keyword arguments.

    Args:
        fn: Dotted path of the form ``"package.module.function"``; everything
            before the last dot is imported as a module and the final
            component is looked up on it.
        **kwargs: Keyword arguments to pre-bind on the loaded function.

    Returns:
        A ``functools.partial`` wrapping the loaded function with ``kwargs``.

    Raises:
        ValueError: If ``fn`` has no module component (no dot).
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with the final name.
    """
    # rpartition splits on the LAST dot: module path on the left, attribute
    # (function) name on the right.
    module_name, _, fn_name = fn.rpartition(".")
    if not module_name:
        raise ValueError(
            f"Expected a fully qualified name like 'module.function', got {fn!r}"
        )
    module = import_module(module_name)
    return partial(getattr(module, fn_name), **kwargs)