summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--text_recognizer/networks/conv_transformer.py16
1 files changed, 8 insertions, 8 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index ba43b31..714792e 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -99,12 +99,12 @@ class ConvTransformer(nn.Module):
z = z.permute(0, 2, 1)
return z
- def decode(self, z: Tensor, context: Tensor) -> Tensor:
+ def decode(self, src: Tensor, trg: Tensor) -> Tensor:
"""Decodes latent images embedding into word pieces.
Args:
- z (Tensor): Latent images embedding.
- context (Tensor): Word embeddings.
+ src (Tensor): Latent images embedding.
+ trg (Tensor): Word embeddings.
Shapes:
- z: :math: `(B, Sx, E)`
@@ -116,11 +116,11 @@ class ConvTransformer(nn.Module):
Returns:
Tensor: Sequence of word piece embeddings.
"""
- context = context.long()
- context_mask = context != self.pad_index
- context = self.token_embedding(context) * math.sqrt(self.hidden_dim)
- context = self.token_pos_encoder(context)
- out = self.decoder(x=context, context=z, mask=context_mask)
+ trg = trg.long()
+ trg_mask = trg != self.pad_index
+ trg = self.token_embedding(trg) * math.sqrt(self.hidden_dim)
+ trg = self.token_pos_encoder(trg)
+ out = self.decoder(x=trg, context=src, mask=trg_mask)
logits = self.head(out) # [B, Sy, T]
logits = logits.permute(0, 2, 1) # [B, T, Sy]
return logits