diff options
Diffstat (limited to 'src/text_recognizer/networks/cnn_transformer.py')
-rw-r--r-- | src/text_recognizer/networks/cnn_transformer.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py index 7133c26..a2d7926 100644 --- a/src/text_recognizer/networks/cnn_transformer.py +++ b/src/text_recognizer/networks/cnn_transformer.py @@ -112,11 +112,11 @@ class CNNTransformer(nn.Module): if self.max_pool is not None: src = self.max_pool(src) - if self.adaptive_pool is not None: + if self.adaptive_pool is not None and len(src.shape) == 4: src = rearrange(src, "b c h w -> b w c h") src = self.adaptive_pool(src) src = src.squeeze(3) - else: + elif len(src.shape) == 4: src = rearrange(src, "b c h w -> b (h w) c") b, t, _ = src.shape |