summaryrefslogtreecommitdiff
path: root/src/text_recognizer
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-08-11 23:08:56 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-08-11 23:08:56 +0200
commit95cbdf5bc1cc9639febda23c28d8f464c998b214 (patch)
tree435faa5645bab4c05b7824f33d8e94a0bc421b66 /src/text_recognizer
parent53677be4ec14854ea4881b0d78730e0414c8dedd (diff)
Working on the cnn lstm ctc model.
Diffstat (limited to 'src/text_recognizer')
-rw-r--r--src/text_recognizer/datasets/util.py1
-rw-r--r--src/text_recognizer/networks/misc.py5
2 files changed, 4 insertions, 2 deletions
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index 321bc67..76bd85f 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -46,6 +46,7 @@ def fetch_data_loaders(
"""
def check_dataset_args(args: Dict, split: str) -> Dict:
+ """Adds train flag to the dataset args."""
args["train"] = True if split == "train" else False
return args
diff --git a/src/text_recognizer/networks/misc.py b/src/text_recognizer/networks/misc.py
index 9440f9d..2fbab8f 100644
--- a/src/text_recognizer/networks/misc.py
+++ b/src/text_recognizer/networks/misc.py
@@ -21,8 +21,9 @@ def sliding_window(
"""
unfold = Unfold(kernel_size=patch_size, stride=stride)
- patches = unfold(images)
+ # Perform the sliding window, unsqueeze as the channel dimension is lost.
+ patches = unfold(images).unsqueeze(1)
patches = rearrange(
- patches, "b (h w) c -> b c h w", h=patch_size[0], w=patch_size[1]
+ patches, "b c (h w) t -> b t c h w", h=patch_size[0], w=patch_size[1]
)
return patches