From 95cbdf5bc1cc9639febda23c28d8f464c998b214 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Tue, 11 Aug 2020 23:08:56 +0200 Subject: Working on the cnn lstm ctc model. --- src/text_recognizer/datasets/util.py | 1 + src/text_recognizer/networks/misc.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) (limited to 'src/text_recognizer') diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py index 321bc67..76bd85f 100644 --- a/src/text_recognizer/datasets/util.py +++ b/src/text_recognizer/datasets/util.py @@ -46,6 +46,7 @@ def fetch_data_loaders( """ def check_dataset_args(args: Dict, split: str) -> Dict: + """Adds train flag to the dataset args.""" args["train"] = True if split == "train" else False return args diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/misc.py index 9440f9d..2fbab8f 100644 --- a/src/text_recognizer/networks/misc.py +++ b/src/text_recognizer/networks/misc.py @@ -21,8 +21,9 @@ def sliding_window( """ unfold = Unfold(kernel_size=patch_size, stride=stride) - patches = unfold(images) + # Perform the sliding window, unsqueeze as the channel dimension is lost. + patches = unfold(images).unsqueeze(1) patches = rearrange( - patches, "b (h w) c -> b c h w", h=patch_size[0], w=patch_size[1] + patches, "b c (h w) t -> b t c h w", h=patch_size[0], w=patch_size[1] ) return patches -- cgit v1.2.3-70-g09d2