Diffstat (limited to 'src/text_recognizer/networks/misc.py')
-rw-r--r-- | src/text_recognizer/networks/misc.py | 45 |
1 files changed, 0 insertions, 45 deletions
diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/misc.py
deleted file mode 100644
index 1f853e9..0000000
--- a/src/text_recognizer/networks/misc.py
+++ /dev/null
@@ -1,45 +0,0 @@
-"""Miscellaneous neural network functionality."""
-from typing import Tuple
-
-from einops import rearrange
-import torch
-from torch import nn
-
-
-def sliding_window(
-    images: torch.Tensor, patch_size: Tuple[int, int], stride: Tuple[int, int]
-) -> torch.Tensor:
-    """Creates patches of an image.
-
-    Args:
-        images (torch.Tensor): A 4D tensor of images, i.e. (batch, channel, height, width).
-        patch_size (Tuple[int, int]): The size of the patches to generate, e.g. 28x28 for EMNIST.
-        stride (Tuple[int, int]): The stride of the sliding window.
-
-    Returns:
-        torch.Tensor: A tensor with the shape (batch, patches, channel, height, width).
-
-    """
-    unfold = nn.Unfold(kernel_size=patch_size, stride=stride)
-    # Perform the sliding window; rearrange restores the channel and spatial dimensions that unfold flattens.
-    c = images.shape[1]
-    patches = unfold(images)
-    patches = rearrange(
-        patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1]
-    )
-    return patches
-
-
-def activation_function(activation: str) -> nn.Module:
-    """Returns the callable activation function."""
-    activation_fns = nn.ModuleDict(
-        [
-            ["elu", nn.ELU(inplace=True)],
-            ["gelu", nn.GELU()],
-            ["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)],
-            ["none", nn.Identity()],
-            ["relu", nn.ReLU(inplace=True)],
-            ["selu", nn.SELU(inplace=True)],
-        ]
-    )
-    return activation_fns[activation.lower()]
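For context, a minimal standalone sketch of what the two removed helpers did, assuming torch and einops are installed. The 56x56 input, 28x28 patch size, and the reduced activation table are illustrative assumptions, not values taken from the repository.

# Minimal usage sketch (not part of the repository); assumes torch and einops
# are installed. Input and patch sizes below are illustrative only.
import torch
from einops import rearrange
from torch import nn

images = torch.randn(2, 1, 56, 56)  # (batch, channel, height, width)
patch_size, stride = (28, 28), (28, 28)

# What sliding_window did: unfold flattens each patch to c*h*w values,
# and rearrange restores the (channel, height, width) patch shape.
unfold = nn.Unfold(kernel_size=patch_size, stride=stride)
patches = unfold(images)  # (2, 1*28*28, 4)
patches = rearrange(patches, "b (c h w) t -> b t c h w", c=1, h=28, w=28)
print(patches.shape)  # torch.Size([2, 4, 1, 28, 28])

# What activation_function did: look up an activation module by name.
activation = nn.ModuleDict(
    [["relu", nn.ReLU(inplace=True)], ["gelu", nn.GELU()], ["none", nn.Identity()]]
)["relu"]
print(activation(torch.tensor([-1.0, 2.0])))  # tensor([0., 2.])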