diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-01 00:36:18 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-01 00:36:18 +0100 |
commit | b4602971794c749d7957e1c8b6c72043a1913be4 (patch) | |
tree | 2c7760c14a2964a8c7bbf6fc8abd4c638356d681 /text_recognizer | |
parent | 04e818853289d4f7cdddb3f09164636985dc5b1d (diff) |
Add check for positional encoding in attn layer
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/networks/conv_transformer.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index 3220d5a..ee939e7 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -1,6 +1,6 @@ """Vision transformer for character recognition.""" import math -from typing import Tuple, Type +from typing import Optional, Tuple, Type from loguru import logger as log from torch import nn, Tensor @@ -22,7 +22,7 @@ class ConvTransformer(nn.Module): encoder: nn.Module, decoder: Decoder, pixel_pos_embedding: Type[nn.Module], - token_pos_embedding: Type[nn.Module], + token_pos_embedding: Optional[Type[nn.Module]] = None, ) -> None: super().__init__() self.input_dims = input_dims @@ -52,11 +52,11 @@ class ConvTransformer(nn.Module): ) # Positional encoding for decoder tokens. - if not decoder.has_pos_emb: + if not self.decoder.has_pos_emb: self.token_pos_embedding = token_pos_embedding else: self.token_pos_embedding = None - log.debug("Decoder already have positional embedding.") + log.debug("Decoder already have a positional embedding.") # Head self.head = nn.Linear( |