path: root/text_recognizer/networks/util.py
author    Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-04-05 20:47:55 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-04-05 20:47:55 +0200
commit    9ae5fa1a88899180f88ddb14d4cef457ceb847e5 (patch)
tree      4fe2bcd82553c8062eb0908ae6442c123addf55d /text_recognizer/networks/util.py
parent    9e54591b7e342edc93b0bb04809a0f54045c6a15 (diff)
Add new training loop with PyTorch Lightning, remove stale files
Diffstat (limited to 'text_recognizer/networks/util.py')
-rw-r--r--  text_recognizer/networks/util.py | 27 +--------------------------
1 file changed, 1 insertion(+), 26 deletions(-)
diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py
index 131a6b4..d292680 100644
--- a/text_recognizer/networks/util.py
+++ b/text_recognizer/networks/util.py
@@ -1,38 +1,13 @@
 """Miscellaneous neural network functionality."""
 import importlib
 from pathlib import Path
-from typing import Dict, Tuple, Type
+from typing import Dict, Type
 
-from einops import rearrange
 from loguru import logger
 import torch
 from torch import nn
 
 
-def sliding_window(
-    images: torch.Tensor, patch_size: Tuple[int, int], stride: Tuple[int, int]
-) -> torch.Tensor:
-    """Creates patches of an image.
-
-    Args:
-        images (torch.Tensor): A 4D tensor of images, i.e. (batch, channel, height, width).
-        patch_size (Tuple[int, int]): The size of the patches to generate, e.g. 28x28 for EMNIST.
-        stride (Tuple[int, int]): The stride of the sliding window.
-
-    Returns:
-        torch.Tensor: A tensor with the shape (batch, patches, channel, height, width).
-
-    """
-    unfold = nn.Unfold(kernel_size=patch_size, stride=stride)
-    # Perform the sliding window; rearrange restores the channel dimension that unfold flattens.
-    c = images.shape[1]
-    patches = unfold(images)
-    patches = rearrange(
-        patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1],
-    )
-    return patches
-
-
 def activation_function(activation: str) -> Type[nn.Module]:
     """Returns the callable activation function."""
     activation_fns = nn.ModuleDict(
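For context, the removed sliding_window relied on nn.Unfold, which flattens each patch into a column of shape (channel * patch_height * patch_width); the einops rearrange then restored the per-patch (channel, height, width) layout. A minimal sketch of that behaviour (the tensor sizes here are illustrative, not from this repository):

    import torch
    from einops import rearrange
    from torch import nn

    # Two single-channel 56x56 images, split into non-overlapping 28x28 patches.
    images = torch.randn(2, 1, 56, 56)
    patch_size, stride = (28, 28), (28, 28)

    unfold = nn.Unfold(kernel_size=patch_size, stride=stride)
    patches = unfold(images)  # (batch, channel * 28 * 28, num_patches) = (2, 784, 4)
    patches = rearrange(
        patches, "b (c h w) t -> b t c h w", c=1, h=patch_size[0], w=patch_size[1]
    )
    print(patches.shape)  # torch.Size([2, 4, 1, 28, 28])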