diff options
Diffstat (limited to 'src/text_recognizer/networks/util.py')
-rw-r--r-- | src/text_recognizer/networks/util.py | 83 |
1 files changed, 83 insertions, 0 deletions
diff --git a/src/text_recognizer/networks/util.py b/src/text_recognizer/networks/util.py new file mode 100644 index 0000000..0d08506 --- /dev/null +++ b/src/text_recognizer/networks/util.py @@ -0,0 +1,83 @@ +"""Miscellaneous neural network functionality.""" +import importlib +from pathlib import Path +from typing import Dict, Tuple, Type + +from einops import rearrange +from loguru import logger +import torch +from torch import nn + + +def sliding_window( + images: torch.Tensor, patch_size: Tuple[int, int], stride: Tuple[int, int] +) -> torch.Tensor: + """Creates patches of an image. + + Args: + images (torch.Tensor): A Torch tensor of a 4D image(s), i.e. (batch, channel, height, width). + patch_size (Tuple[int, int]): The size of the patches to generate, e.g. 28x28 for EMNIST. + stride (Tuple[int, int]): The stride of the sliding window. + + Returns: + torch.Tensor: A tensor with the shape (batch, patches, height, width). + + """ + unfold = nn.Unfold(kernel_size=patch_size, stride=stride) + # Preform the slidning window, unsqueeze as the channel dimesion is lost. + c = images.shape[1] + patches = unfold(images) + patches = rearrange( + patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1] + ) + return patches + + +def activation_function(activation: str) -> Type[nn.Module]: + """Returns the callable activation function.""" + activation_fns = nn.ModuleDict( + [ + ["elu", nn.ELU(inplace=True)], + ["gelu", nn.GELU()], + ["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)], + ["none", nn.Identity()], + ["relu", nn.ReLU(inplace=True)], + ["selu", nn.SELU(inplace=True)], + ] + ) + return activation_fns[activation.lower()] + + +def configure_backbone(backbone: str, backbone_args: Dict) -> Type[nn.Module]: + """Loads a backbone network.""" + network_module = importlib.import_module("text_recognizer.networks") + backbone_ = getattr(network_module, backbone) + + if "pretrained" in backbone_args: + logger.info("Loading pretrained backbone.") + checkpoint_file = Path(__file__).resolve().parents[2] / backbone_args.pop( + "pretrained" + ) + + # Loading state directory. + state_dict = torch.load(checkpoint_file) + network_args = state_dict["network_args"] + weights = state_dict["model_state"] + + # Initializes the network with trained weights. + backbone = backbone_(**network_args) + backbone.load_state_dict(weights) + if "freeze" in backbone_args and backbone_args["freeze"] is True: + for params in backbone.parameters(): + params.requires_grad = False + + else: + backbone_ = getattr(network_module, backbone) + backbone = backbone_(**backbone_args) + + if "remove_layers" in backbone_args and backbone_args["remove_layers"] is not None: + backbone = nn.Sequential( + *list(backbone.children())[0][: -backbone_args["remove_layers"]] + ) + + return backbone |