summaryrefslogtreecommitdiff
path: root/text_recognizer
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer')
-rw-r--r--text_recognizer/networks/conv_transformer.py13
1 files changed, 11 insertions, 2 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index 0788b88..3220d5a 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -2,6 +2,7 @@
import math
from typing import Tuple, Type
+from loguru import logger as log
from torch import nn, Tensor
from text_recognizer.networks.transformer.layers import Decoder
@@ -51,7 +52,11 @@ class ConvTransformer(nn.Module):
)
# Positional encoding for decoder tokens.
- self.token_pos_embedding = token_pos_embedding
+ if not decoder.has_pos_emb:
+ self.token_pos_embedding = token_pos_embedding
+ else:
+ self.token_pos_embedding = None
+ log.debug("Decoder already have positional embedding.")
# Head
self.head = nn.Linear(
@@ -112,7 +117,11 @@ class ConvTransformer(nn.Module):
trg = trg.long()
trg_mask = trg != self.pad_index
trg = self.token_embedding(trg) * math.sqrt(self.hidden_dim)
- trg = self.token_pos_embedding(trg)
+ trg = (
+ self.token_pos_embedding(trg)
+ if self.token_pos_embedding is not None
+ else trg
+ )
out = self.decoder(x=trg, context=src, mask=trg_mask)
logits = self.head(out) # [B, Sy, T]
logits = logits.permute(0, 2, 1) # [B, T, Sy]