"""Miscellaneous neural network functionality."""
import importlib
from pathlib import Path
from typing import Dict, Tuple

from einops import rearrange
from loguru import logger
import torch
from torch import nn


def sliding_window(
    images: torch.Tensor, patch_size: Tuple[int, int], stride: Tuple[int, int]
) -> torch.Tensor:
    """Creates patches of an image.

    Args:
        images (torch.Tensor): A 4D tensor of images with shape (batch, channel, height, width).
        patch_size (Tuple[int, int]): The size of the patches to generate, e.g. 28x28 for EMNIST.
        stride (Tuple[int, int]): The stride of the sliding window.

    Returns:
        torch.Tensor: A tensor of patches with shape (batch, patches, channel, height, width).

    """
    unfold = nn.Unfold(kernel_size=patch_size, stride=stride)
    # Perform the sliding window; Unfold flattens the channel and patch
    # dimensions, so the rearrange below restores them as separate dimensions.
    c = images.shape[1]
    patches = unfold(images)
    patches = rearrange(
        patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1],
    )
    return patches
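
# Illustrative usage sketch (not part of the module API); the tensor sizes
# below are assumptions chosen to mirror the EMNIST example in the docstring:
#
#     images = torch.rand(8, 1, 28, 952)  # (batch, channel, height, width)
#     patches = sliding_window(images, patch_size=(28, 28), stride=(28, 28))
#     assert patches.shape == (8, 34, 1, 28, 28)  # (batch, patches, channel, height, width)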


def activation_function(activation: str) -> nn.Module:
    """Returns the activation function module for the given name."""
    activation_fns = nn.ModuleDict(
        {
            "elu": nn.ELU(inplace=True),
            "gelu": nn.GELU(),
            "leaky_relu": nn.LeakyReLU(negative_slope=1.0e-2, inplace=True),
            "none": nn.Identity(),
            "relu": nn.ReLU(inplace=True),
            "selu": nn.SELU(inplace=True),
        }
    )
    return activation_fns[activation.lower()]
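
# Illustrative usage sketch (assumed call site, not part of the module API):
#
#     act = activation_function("leaky_relu")
#     out = act(torch.randn(2, 10))  # applies LeakyReLU(negative_slope=0.01) in place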


def configure_backbone(backbone: str, backbone_args: Dict) -> nn.Module:
    """Loads a backbone network."""
    network_module = importlib.import_module("text_recognizer.networks")
    backbone_ = getattr(network_module, backbone)

    if "pretrained" in backbone_args:
        logger.info("Loading pretrained backbone.")
        checkpoint_file = Path(__file__).resolve().parents[2] / backbone_args.pop(
            "pretrained"
        )

        # Load the checkpoint state dict.
        state_dict = torch.load(checkpoint_file)
        network_args = state_dict["network_args"]
        weights = state_dict["model_state"]

        # Initializes the network with trained weights.
        backbone = backbone_(**network_args)
        backbone.load_state_dict(weights)
        if "freeze" in backbone_args and backbone_args["freeze"] is True:
            for params in backbone.parameters():
                params.requires_grad = False

    else:
        backbone = backbone_(**backbone_args)

    if "remove_layers" in backbone_args and backbone_args["remove_layers"] is not None:
        backbone = nn.Sequential(
            *list(backbone.children())[:][: -backbone_args["remove_layers"]]
        )

    return backbone
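
# Illustrative usage sketch; the backbone name, checkpoint path, and argument
# values below are assumptions, not files or classes shipped with this module:
#
#     backbone = configure_backbone(
#         backbone="ResidualNetwork",
#         backbone_args={
#             "pretrained": "weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt",
#             "freeze": True,
#             "remove_layers": 1,
#         },
#     )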