diff options
Diffstat (limited to 'text_recognizer/networks/util.py')
-rw-r--r-- | text_recognizer/networks/util.py | 27 |
1 files changed, 1 insertions, 26 deletions
diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py index 131a6b4..d292680 100644 --- a/text_recognizer/networks/util.py +++ b/text_recognizer/networks/util.py @@ -1,38 +1,13 @@ """Miscellaneous neural network functionality.""" import importlib from pathlib import Path -from typing import Dict, Tuple, Type +from typing import Dict, Type -from einops import rearrange from loguru import logger import torch from torch import nn -def sliding_window( - images: torch.Tensor, patch_size: Tuple[int, int], stride: Tuple[int, int] -) -> torch.Tensor: - """Creates patches of an image. - - Args: - images (torch.Tensor): A Torch tensor of a 4D image(s), i.e. (batch, channel, height, width). - patch_size (Tuple[int, int]): The size of the patches to generate, e.g. 28x28 for EMNIST. - stride (Tuple[int, int]): The stride of the sliding window. - - Returns: - torch.Tensor: A tensor with the shape (batch, patches, height, width). - - """ - unfold = nn.Unfold(kernel_size=patch_size, stride=stride) - # Preform the sliding window, unsqueeze as the channel dimesion is lost. - c = images.shape[1] - patches = unfold(images) - patches = rearrange( - patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1], - ) - return patches - - def activation_function(activation: str) -> Type[nn.Module]: """Returns the callable activation function.""" activation_fns = nn.ModuleDict( |