"""Miscellaneous neural network functionality."""
import importlib
from pathlib import Path
from typing import Dict, Tuple, Type
from einops import rearrange
from loguru import logger
import torch
from torch import nn


def sliding_window(
    images: torch.Tensor, patch_size: Tuple[int, int], stride: Tuple[int, int]
) -> torch.Tensor:
    """Creates patches from a batch of images.

    Args:
        images (torch.Tensor): A 4D tensor of images, i.e. (batch, channel, height, width).
        patch_size (Tuple[int, int]): The size of the patches to generate, e.g. 28x28 for EMNIST.
        stride (Tuple[int, int]): The stride of the sliding window.

    Returns:
        torch.Tensor: A tensor with the shape (batch, patches, channel, height, width).
    """
    unfold = nn.Unfold(kernel_size=patch_size, stride=stride)
    # Unfold flattens each patch to (channel * height * width), so rearrange is
    # used afterwards to split the channel, height, and width dimensions back out.
    c = images.shape[1]
    patches = unfold(images)
    patches = rearrange(
        patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1],
    )
    return patches
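
# A minimal usage sketch (the shapes below are illustrative assumptions, not
# fixed by this module): two single-channel 28x56 images split into two
# non-overlapping 28x28 patches each.
#
#   images = torch.rand(2, 1, 28, 56)
#   patches = sliding_window(images, patch_size=(28, 28), stride=(28, 28))
#   patches.shape  # torch.Size([2, 2, 1, 28, 28])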


def activation_function(activation: str) -> nn.Module:
    """Returns the activation function module matching the given name."""
    activation_fns = nn.ModuleDict(
        [
            ["elu", nn.ELU(inplace=True)],
            ["gelu", nn.GELU()],
            ["glu", nn.GLU()],
            ["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)],
            ["none", nn.Identity()],
            ["relu", nn.ReLU(inplace=True)],
            ["selu", nn.SELU(inplace=True)],
        ]
    )
    return activation_fns[activation.lower()]
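
# Usage sketch: look up an activation by name and apply it like any module.
# With the negative slope of 1.0e-2 above, negative inputs are scaled by 0.01.
#
#   act = activation_function("leaky_relu")
#   act(torch.tensor([-1.0, 2.0]))  # -> tensor([-0.0100,  2.0000])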


def configure_backbone(backbone: str, backbone_args: Dict) -> nn.Module:
    """Loads a backbone network, optionally from a pretrained checkpoint."""
    network_module = importlib.import_module("text_recognizer.networks")
    backbone_ = getattr(network_module, backbone)

    if "pretrained" in backbone_args:
        logger.info("Loading pretrained backbone.")
        # The checkpoint path is resolved relative to the repository root.
        checkpoint_file = Path(__file__).resolve().parents[2] / backbone_args.pop(
            "pretrained"
        )

        # Load the state dictionary with the network arguments and weights.
        state_dict = torch.load(checkpoint_file)
        network_args = state_dict["network_args"]
        weights = state_dict["model_state"]

        # Initialize the network with the trained weights.
        backbone = backbone_(**network_args)
        backbone.load_state_dict(weights)
        if "freeze" in backbone_args and backbone_args["freeze"] is True:
            for params in backbone.parameters():
                params.requires_grad = False
    else:
        backbone = backbone_(**backbone_args)

    if "remove_layers" in backbone_args and backbone_args["remove_layers"] is not None:
        # Drop the last `remove_layers` top-level children, e.g. a classification head.
        backbone = nn.Sequential(
            *list(backbone.children())[: -backbone_args["remove_layers"]]
        )

    return backbone
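
# A hypothetical configuration sketch; the backbone name, checkpoint path, and
# argument values below are illustrative assumptions, not part of this module.
# The network class is looked up by name in text_recognizer.networks, and the
# "pretrained" path is taken relative to the repository root.
#
#   backbone = configure_backbone(
#       "WideResidualNetwork",
#       {"pretrained": "weights/wrn_pretrained.pt", "freeze": True, "remove_layers": 1},
#   )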