From beab369f59c54de888e522d2f50602e758e3cc4b Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Thu, 28 Oct 2021 21:20:21 +0200 Subject: Add check for position embedding --- text_recognizer/networks/conv_transformer.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) (limited to 'text_recognizer/networks/conv_transformer.py') diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index 0788b88..3220d5a 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -2,6 +2,7 @@ import math from typing import Tuple, Type +from loguru import logger as log from torch import nn, Tensor from text_recognizer.networks.transformer.layers import Decoder @@ -51,7 +52,11 @@ class ConvTransformer(nn.Module): ) # Positional encoding for decoder tokens. - self.token_pos_embedding = token_pos_embedding + if not decoder.has_pos_emb: + self.token_pos_embedding = token_pos_embedding + else: + self.token_pos_embedding = None + log.debug("Decoder already have positional embedding.") # Head self.head = nn.Linear( @@ -112,7 +117,11 @@ class ConvTransformer(nn.Module): trg = trg.long() trg_mask = trg != self.pad_index trg = self.token_embedding(trg) * math.sqrt(self.hidden_dim) - trg = self.token_pos_embedding(trg) + trg = ( + self.token_pos_embedding(trg) + if self.token_pos_embedding is not None + else trg + ) out = self.decoder(x=trg, context=src, mask=trg_mask) logits = self.head(out) # [B, Sy, T] logits = logits.permute(0, 2, 1) # [B, T, Sy] -- cgit v1.2.3-70-g09d2