diff options
Diffstat (limited to 'text_recognizer/networks/conv_transformer.py')
-rw-r--r-- | text_recognizer/networks/conv_transformer.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index 3220d5a..ee939e7 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -1,6 +1,6 @@ """Vision transformer for character recognition.""" import math -from typing import Tuple, Type +from typing import Optional, Tuple, Type from loguru import logger as log from torch import nn, Tensor @@ -22,7 +22,7 @@ class ConvTransformer(nn.Module): encoder: nn.Module, decoder: Decoder, pixel_pos_embedding: Type[nn.Module], - token_pos_embedding: Type[nn.Module], + token_pos_embedding: Optional[Type[nn.Module]] = None, ) -> None: super().__init__() self.input_dims = input_dims @@ -52,11 +52,11 @@ class ConvTransformer(nn.Module): ) # Positional encoding for decoder tokens. - if not decoder.has_pos_emb: + if not self.decoder.has_pos_emb: self.token_pos_embedding = token_pos_embedding else: self.token_pos_embedding = None - log.debug("Decoder already have positional embedding.") + log.debug("Decoder already have a positional embedding.") # Head self.head = nn.Linear( |