From b4602971794c749d7957e1c8b6c72043a1913be4 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Mon, 1 Nov 2021 00:36:18 +0100
Subject: Add check for positional encoding in attn layer

---
 text_recognizer/networks/conv_transformer.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index 3220d5a..ee939e7 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -1,6 +1,6 @@
 """Vision transformer for character recognition."""
 import math
-from typing import Tuple, Type
+from typing import Optional, Tuple, Type
 
 from loguru import logger as log
 from torch import nn, Tensor
@@ -22,7 +22,7 @@ class ConvTransformer(nn.Module):
         encoder: nn.Module,
         decoder: Decoder,
         pixel_pos_embedding: Type[nn.Module],
-        token_pos_embedding: Type[nn.Module],
+        token_pos_embedding: Optional[Type[nn.Module]] = None,
     ) -> None:
         super().__init__()
         self.input_dims = input_dims
@@ -52,11 +52,11 @@ class ConvTransformer(nn.Module):
         )
 
         # Positional encoding for decoder tokens.
-        if not decoder.has_pos_emb:
+        if not self.decoder.has_pos_emb:
             self.token_pos_embedding = token_pos_embedding
         else:
             self.token_pos_embedding = None
-            log.debug("Decoder already have positional embedding.")
+            log.debug("Decoder already have a positional embedding.")
 
         # Head
         self.head = nn.Linear(
-- 
cgit v1.2.3-70-g09d2
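
Note: a minimal sketch (not part of the patch) of the conditional this commit adds to ConvTransformer.__init__: the token positional embedding is only kept when the decoder does not already apply one of its own. The standalone helper name is hypothetical; has_pos_emb and token_pos_embedding mirror the patched code.

    from typing import Optional, Type

    from torch import nn

    def resolve_token_pos_embedding(
        decoder: nn.Module,
        token_pos_embedding: Optional[Type[nn.Module]] = None,
    ) -> Optional[Type[nn.Module]]:
        # Hypothetical helper: keep the token positional embedding only when
        # the decoder does not already carry its own positional embedding.
        if not getattr(decoder, "has_pos_emb", False):
            return token_pos_embedding
        return None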