Diffstat (limited to 'src/text_recognizer/networks/util.py')
-rw-r--r--    src/text_recognizer/networks/util.py    83
1 file changed, 83 insertions(+), 0 deletions(-)
diff --git a/src/text_recognizer/networks/util.py b/src/text_recognizer/networks/util.py
new file mode 100644
index 0000000..b31e640
--- /dev/null
+++ b/src/text_recognizer/networks/util.py
@@ -0,0 +1,83 @@
+"""Miscellaneous neural network functionality."""
+import importlib
+from pathlib import Path
+from typing import Dict, Tuple
+
+from einops import rearrange
+from loguru import logger
+import torch
+from torch import nn
+
+
+def sliding_window(
+    images: torch.Tensor, patch_size: Tuple[int, int], stride: Tuple[int, int]
+) -> torch.Tensor:
+    """Creates patches of an image.
+
+    Args:
+        images (torch.Tensor): A 4D tensor of images, i.e. (batch, channel, height, width).
+        patch_size (Tuple[int, int]): The size of the patches to generate, e.g. 28x28 for EMNIST.
+        stride (Tuple[int, int]): The stride of the sliding window.
+
+    Returns:
+        torch.Tensor: A tensor with the shape (batch, patches, channel, height, width).
+
+    """
+    unfold = nn.Unfold(kernel_size=patch_size, stride=stride)
+    # Perform the sliding window; unfold flattens the channels, so rearrange restores them.
+    c = images.shape[1]
+    patches = unfold(images)
+    patches = rearrange(
+        patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1],
+    )
+    return patches
+
+
+def activation_function(activation: str) -> nn.Module:
+    """Returns the activation function module for the given name."""
+    activation_fns = nn.ModuleDict(
+        [
+            ["elu", nn.ELU(inplace=True)],
+            ["gelu", nn.GELU()],
+            ["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)],
+            ["none", nn.Identity()],
+            ["relu", nn.ReLU(inplace=True)],
+            ["selu", nn.SELU(inplace=True)],
+        ]
+    )
+    return activation_fns[activation.lower()]
+
+
+def configure_backbone(backbone: str, backbone_args: Dict) -> nn.Module:
+    """Loads a backbone network."""
+    network_module = importlib.import_module("text_recognizer.networks")
+    backbone_ = getattr(network_module, backbone)
+
+    if "pretrained" in backbone_args:
+        logger.info("Loading pretrained backbone.")
+        checkpoint_file = Path(__file__).resolve().parents[2] / backbone_args.pop(
+            "pretrained"
+        )
+
+        # Load the state dict from the checkpoint file.
+        state_dict = torch.load(checkpoint_file)
+        network_args = state_dict["network_args"]
+        weights = state_dict["model_state"]
+
+        # Initialize the network with the trained weights.
+        backbone = backbone_(**network_args)
+        backbone.load_state_dict(weights)
+        if backbone_args.get("freeze", False):
+            for params in backbone.parameters():
+                params.requires_grad = False
+
+    else:
+        # Initialize the backbone from the given arguments.
+        backbone = backbone_(**backbone_args)
+
+    if "remove_layers" in backbone_args and backbone_args["remove_layers"] is not None:
+        backbone = nn.Sequential(
+            *list(backbone.children())[: -backbone_args["remove_layers"]]
+        )
+
+    return backbone
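
A minimal usage sketch of sliding_window from the file above. The batch of four single-channel 56x56 images and the non-overlapping 28x28 patch size are illustrative assumptions, not values taken from the repository.

import torch

from text_recognizer.networks.util import sliding_window

# Hypothetical batch: 4 single-channel 56x56 images.
images = torch.rand(4, 1, 56, 56)

# Non-overlapping 28x28 patches: unfold yields 4 patches per image here.
patches = sliding_window(images, patch_size=(28, 28), stride=(28, 28))
print(patches.shape)  # torch.Size([4, 4, 1, 28, 28]), i.e. (batch, patches, channel, height, width)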
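
A similarly minimal sketch of looking up an activation module by name with activation_function; the input tensor shape is an arbitrary assumption.

import torch

from text_recognizer.networks.util import activation_function

# "leaky_relu" maps to nn.LeakyReLU(negative_slope=1.0e-2, inplace=True) in the ModuleDict above.
activation = activation_function("leaky_relu")

x = torch.randn(2, 8)  # arbitrary example input
y = activation(x)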
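
And a hedged sketch of configure_backbone. The backbone name "CNN", the checkpoint path, and every value in backbone_args are placeholders for illustration; the actual backbone classes exported by text_recognizer.networks are not shown in this diff.

from text_recognizer.networks.util import configure_backbone

# All names and values below are hypothetical placeholders.
backbone_args = {
    "pretrained": "weights/backbone_weights.pt",  # resolved relative to the src/ directory
    "freeze": True,         # freeze the pretrained parameters
    "remove_layers": 1,     # wrap the children in nn.Sequential, dropping the last module
}
backbone = configure_backbone("CNN", backbone_args)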