summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-01 00:36:18 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-01 00:36:18 +0100
commitb4602971794c749d7957e1c8b6c72043a1913be4 (patch)
tree2c7760c14a2964a8c7bbf6fc8abd4c638356d681
parent04e818853289d4f7cdddb3f09164636985dc5b1d (diff)
Add check for positional encoding in attn layer
-rw-r--r--text_recognizer/networks/conv_transformer.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index 3220d5a..ee939e7 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -1,6 +1,6 @@
"""Vision transformer for character recognition."""
import math
-from typing import Tuple, Type
+from typing import Optional, Tuple, Type
from loguru import logger as log
from torch import nn, Tensor
@@ -22,7 +22,7 @@ class ConvTransformer(nn.Module):
encoder: nn.Module,
decoder: Decoder,
pixel_pos_embedding: Type[nn.Module],
- token_pos_embedding: Type[nn.Module],
+ token_pos_embedding: Optional[Type[nn.Module]] = None,
) -> None:
super().__init__()
self.input_dims = input_dims
@@ -52,11 +52,11 @@ class ConvTransformer(nn.Module):
)
# Positional encoding for decoder tokens.
- if not decoder.has_pos_emb:
+ if not self.decoder.has_pos_emb:
self.token_pos_embedding = token_pos_embedding
else:
self.token_pos_embedding = None
- log.debug("Decoder already have positional embedding.")
+ log.debug("Decoder already have a positional embedding.")
# Head
self.head = nn.Linear(