From b4602971794c749d7957e1c8b6c72043a1913be4 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 1 Nov 2021 00:36:18 +0100 Subject: Add check for positional encoding in attn layer --- text_recognizer/networks/conv_transformer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'text_recognizer/networks/conv_transformer.py') diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index 3220d5a..ee939e7 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -1,6 +1,6 @@ """Vision transformer for character recognition.""" import math -from typing import Tuple, Type +from typing import Optional, Tuple, Type from loguru import logger as log from torch import nn, Tensor @@ -22,7 +22,7 @@ class ConvTransformer(nn.Module): encoder: nn.Module, decoder: Decoder, pixel_pos_embedding: Type[nn.Module], - token_pos_embedding: Type[nn.Module], + token_pos_embedding: Optional[Type[nn.Module]] = None, ) -> None: super().__init__() self.input_dims = input_dims @@ -52,11 +52,11 @@ class ConvTransformer(nn.Module): ) # Positional encoding for decoder tokens. - if not decoder.has_pos_emb: + if not self.decoder.has_pos_emb: self.token_pos_embedding = token_pos_embedding else: self.token_pos_embedding = None - log.debug("Decoder already have positional embedding.") + log.debug("Decoder already have a positional embedding.") # Head self.head = nn.Linear( -- cgit v1.2.3-70-g09d2