summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/misc.py
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-09-08 23:14:23 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-09-08 23:14:23 +0200
commite1b504bca41a9793ed7e88ef14f2e2cbd85724f2 (patch)
tree70b482f890c9ad2be104f0bff8f2172e8411a2be /src/text_recognizer/networks/misc.py
parentfe23001b6588e6e6e9e2c5a99b72f3445cf5206f (diff)
IAM datasets implemented.
Diffstat (limited to 'src/text_recognizer/networks/misc.py')
-rw-r--r--src/text_recognizer/networks/misc.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/misc.py
index 6f61b5d..cac9e78 100644
--- a/src/text_recognizer/networks/misc.py
+++ b/src/text_recognizer/networks/misc.py
@@ -22,9 +22,10 @@ def sliding_window(
"""
unfold = nn.Unfold(kernel_size=patch_size, stride=stride)
# Perform the sliding window, unsqueeze as the channel dimension is lost.
- patches = unfold(images).unsqueeze(1)
+ c = images.shape[1]
+ patches = unfold(images)
patches = rearrange(
- patches, "b c (h w) t -> b t c h w", h=patch_size[0], w=patch_size[1]
+ patches, "b (c h w) t -> b t c h w", c=c, h=patch_size[0], w=patch_size[1]
)
return patches