summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/util.py
blob: 85094f164e9badd64efabc7dfafe2f8b4af19369 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
"""Miscellaneous neural network utility functionality."""
from typing import Type

from torch import nn


def activation_function(activation: str) -> nn.Module:
    """Returns a freshly constructed activation module for the given name.

    Only the requested activation is instantiated. Asking for one activation
    therefore does not construct (or require) every other activation class,
    e.g. ``nn.Mish``, which older torch versions lack.

    Args:
        activation: Case-insensitive name of the activation. One of ``"elu"``,
            ``"gelu"``, ``"glu"``, ``"leaky_relu"``, ``"none"`` (identity),
            ``"relu"``, ``"selu"`` or ``"mish"``.

    Returns:
        A new ``nn.Module`` instance on every call. ELU, LeakyReLU, ReLU, SELU
        and Mish are created with ``inplace=True``.

    Raises:
        KeyError: If ``activation`` is not one of the supported names.
    """
    # Zero-argument factories: the module is only built (and the torch class
    # only looked up) for the activation that was actually requested.
    factories = {
        "elu": lambda: nn.ELU(inplace=True),
        "gelu": lambda: nn.GELU(),
        "glu": lambda: nn.GLU(),
        "leaky_relu": lambda: nn.LeakyReLU(negative_slope=1.0e-2, inplace=True),
        "none": lambda: nn.Identity(),
        "relu": lambda: nn.ReLU(inplace=True),
        "selu": lambda: nn.SELU(inplace=True),
        "mish": lambda: nn.Mish(inplace=True),
    }
    try:
        factory = factories[activation.lower()]
    except KeyError:
        # Keep KeyError (what the original ModuleDict lookup raised) so callers
        # catching it still work, but make the message actionable.
        raise KeyError(
            f"Unsupported activation function '{activation}'. "
            f"Expected one of: {sorted(factories)}"
        ) from None
    return factory()