summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/misc.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/networks/misc.py')
-rw-r--r--src/text_recognizer/networks/misc.py28
1 files changed, 28 insertions, 0 deletions
diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/misc.py
new file mode 100644
index 0000000..9440f9d
--- /dev/null
+++ b/src/text_recognizer/networks/misc.py
@@ -0,0 +1,28 @@
+"""Miscellaneous neural network functionality."""
+from typing import Tuple
+
+from einops import rearrange
+import torch
+from torch.nn import Unfold
+
+
+def sliding_window(
+ images: torch.Tensor, patch_size: Tuple[int, int], stride: Tuple[int, int]
+) -> torch.Tensor:
+ """Creates patches of an image.
+
+ Args:
+ images (torch.Tensor): A Torch tensor of a 4D image(s), i.e. (batch, channel, height, width).
+ patch_size (Tuple[int, int]): The size of the patches to generate, e.g. 28x28 for EMNIST.
+ stride (Tuple[int, int]): The stride of the sliding window.
+
+ Returns:
+ torch.Tensor: A tensor with the shape (batch, patches, height, width).
+
+ """
+ unfold = Unfold(kernel_size=patch_size, stride=stride)
+ patches = unfold(images)
+ patches = rearrange(
+ patches, "b (h w) c -> b c h w", h=patch_size[0], w=patch_size[1]
+ )
+ return patches