diff options
Diffstat (limited to 'text_recognizer/networks/conv_transformer.py')
-rw-r--r-- | text_recognizer/networks/conv_transformer.py | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index 0788b88..3220d5a 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -2,6 +2,7 @@ import math from typing import Tuple, Type +from loguru import logger as log from torch import nn, Tensor from text_recognizer.networks.transformer.layers import Decoder @@ -51,7 +52,11 @@ class ConvTransformer(nn.Module): ) # Positional encoding for decoder tokens. - self.token_pos_embedding = token_pos_embedding + if not decoder.has_pos_emb: + self.token_pos_embedding = token_pos_embedding + else: + self.token_pos_embedding = None + log.debug("Decoder already have positional embedding.") # Head self.head = nn.Linear( @@ -112,7 +117,11 @@ class ConvTransformer(nn.Module): trg = trg.long() trg_mask = trg != self.pad_index trg = self.token_embedding(trg) * math.sqrt(self.hidden_dim) - trg = self.token_pos_embedding(trg) + trg = ( + self.token_pos_embedding(trg) + if self.token_pos_embedding is not None + else trg + ) out = self.decoder(x=trg, context=src, mask=trg_mask) logits = self.head(out) # [B, Sy, T] logits = logits.permute(0, 2, 1) # [B, T, Sy] |