diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-28 21:20:21 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-28 21:20:21 +0200 |
commit | beab369f59c54de888e522d2f50602e758e3cc4b (patch) | |
tree | a64ff3b399366474a9bf7e54a0c3c182bd40f065 | |
parent | 01b11ead9470b40ca24e41dca59ac6a8b3f65186 (diff) |
Add check for position embedding
-rw-r--r-- | text_recognizer/networks/conv_transformer.py | 13 |
1 file changed, 11 insertions, 2 deletions
```diff
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index 0788b88..3220d5a 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -2,6 +2,7 @@
 import math
 from typing import Tuple, Type
 
+from loguru import logger as log
 from torch import nn, Tensor
 
 from text_recognizer.networks.transformer.layers import Decoder
@@ -51,7 +52,11 @@ class ConvTransformer(nn.Module):
         )
 
         # Positional encoding for decoder tokens.
-        self.token_pos_embedding = token_pos_embedding
+        if not decoder.has_pos_emb:
+            self.token_pos_embedding = token_pos_embedding
+        else:
+            self.token_pos_embedding = None
+            log.debug("Decoder already have positional embedding.")
 
         # Head
         self.head = nn.Linear(
@@ -112,7 +117,11 @@ class ConvTransformer(nn.Module):
         trg = trg.long()
         trg_mask = trg != self.pad_index
         trg = self.token_embedding(trg) * math.sqrt(self.hidden_dim)
-        trg = self.token_pos_embedding(trg)
+        trg = (
+            self.token_pos_embedding(trg)
+            if self.token_pos_embedding is not None
+            else trg
+        )
         out = self.decoder(x=trg, context=src, mask=trg_mask)
         logits = self.head(out)  # [B, Sy, T]
         logits = logits.permute(0, 2, 1)  # [B, T, Sy]
```