From 53677be4ec14854ea4881b0d78730e0414c8dedd Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Sun, 9 Aug 2020 23:24:02 +0200 Subject: Working bash scripts etc. --- src/text_recognizer/networks/misc.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 src/text_recognizer/networks/misc.py (limited to 'src/text_recognizer/networks/misc.py') diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/misc.py new file mode 100644 index 0000000..9440f9d --- /dev/null +++ b/src/text_recognizer/networks/misc.py @@ -0,0 +1,28 @@ +"""Miscellaneous neural network functionality.""" +from typing import Tuple + +from einops import rearrange +import torch +from torch.nn import Unfold + + +def sliding_window( + images: torch.Tensor, patch_size: Tuple[int, int], stride: Tuple[int, int] +) -> torch.Tensor: + """Creates patches of an image. + + Args: + images (torch.Tensor): A Torch tensor of a 4D image(s), i.e. (batch, channel, height, width). + patch_size (Tuple[int, int]): The size of the patches to generate, e.g. 28x28 for EMNIST. + stride (Tuple[int, int]): The stride of the sliding window. + + Returns: + torch.Tensor: A tensor with the shape (batch, patches, height, width). + + """ + unfold = Unfold(kernel_size=patch_size, stride=stride) + patches = unfold(images) + patches = rearrange( + patches, "b (h w) c -> b c h w", h=patch_size[0], w=patch_size[1] + ) + return patches -- cgit v1.2.3-70-g09d2