Diffstat (limited to 'src/text_recognizer/networks/misc.py')
-rw-r--r--   src/text_recognizer/networks/misc.py   20
1 file changed, 17 insertions(+), 3 deletions(-)
diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/misc.py
index 2fbab8f..6f61b5d 100644
--- a/src/text_recognizer/networks/misc.py
+++ b/src/text_recognizer/networks/misc.py
@@ -1,9 +1,9 @@
 """Miscellaneous neural network functionality."""
-from typing import Tuple
+from typing import Tuple, Type
 
 from einops import rearrange
 import torch
-from torch.nn import Unfold
+from torch import nn
 
 
 def sliding_window(
@@ -20,10 +20,24 @@ def sliding_window(
         torch.Tensor: A tensor with the shape (batch, patches, height, width).
 
     """
-    unfold = Unfold(kernel_size=patch_size, stride=stride)
+    unfold = nn.Unfold(kernel_size=patch_size, stride=stride)
     # Perform the sliding window, unsqueeze as the channel dimension is lost.
     patches = unfold(images).unsqueeze(1)
     patches = rearrange(
         patches, "b c (h w) t -> b t c h w", h=patch_size[0], w=patch_size[1]
     )
     return patches
+
+
+def activation_function(activation: str) -> Type[nn.Module]:
+    """Returns the callable activation function."""
+    activation_fns = nn.ModuleDict(
+        [
+            ["gelu", nn.GELU()],
+            ["leaky_relu", nn.LeakyReLU(negative_slope=1.0e-2, inplace=True)],
+            ["none", nn.Identity()],
+            ["relu", nn.ReLU(inplace=True)],
+            ["selu", nn.SELU(inplace=True)],
+        ]
+    )
+    return activation_fns[activation.lower()]
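
For readers wiring these helpers into a network, a minimal usage sketch follows. It assumes src/ is on the Python path so the module imports as text_recognizer.networks.misc, and it infers sliding_window's parameter names (images, patch_size, stride) from the function body shown above; neither detail is confirmed by the diff itself. Note that activation_function returns an instantiated module pulled from the nn.ModuleDict, not a class, despite the Type[nn.Module] annotation.

import torch

from text_recognizer.networks.misc import activation_function, sliding_window

# Look up an activation by name; the key is lower-cased first, so "ReLU"
# and "relu" resolve to the same entry.
leaky_relu = activation_function("leaky_relu")
x = torch.randn(4, 8)
y = leaky_relu(x)  # LeakyReLU with negative_slope=1e-2; inplace=True mutates x.

# Cut a batch of single-channel 28x28 images into four non-overlapping
# 14x14 patches each; the unsqueeze inside sliding_window restores the
# channel dimension that nn.Unfold flattens away.
images = torch.randn(2, 1, 28, 28)
patches = sliding_window(images, patch_size=(14, 14), stride=(14, 14))
print(patches.shape)  # torch.Size([2, 4, 1, 14, 14])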