From a6d6d50ba72556fcd3ca736e8c5a9bc59516cd32 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Thu, 30 Sep 2021 23:06:36 +0200
Subject: Rename z/context to src/trg in conv transformer decode

---
 text_recognizer/networks/conv_transformer.py | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index ba43b31..714792e 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -99,12 +99,12 @@ class ConvTransformer(nn.Module):
         z = z.permute(0, 2, 1)
         return z
 
-    def decode(self, z: Tensor, context: Tensor) -> Tensor:
+    def decode(self, src: Tensor, trg: Tensor) -> Tensor:
         """Decodes latent images embedding into word pieces.
 
         Args:
-            z (Tensor): Latent images embedding.
-            context (Tensor): Word embeddings.
+            src (Tensor): Latent image embedding.
+            trg (Tensor): Target token indices.
 
         Shapes:
-            - z: :math: `(B, Sx, E)`
+            - src: :math: `(B, Sx, E)`
@@ -116,11 +116,11 @@ class ConvTransformer(nn.Module):
         Returns:
-            Tensor: Sequence of word piece embeddings.
+            Tensor: Word piece logits.
         """
-        context = context.long()
-        context_mask = context != self.pad_index
-        context = self.token_embedding(context) * math.sqrt(self.hidden_dim)
-        context = self.token_pos_encoder(context)
-        out = self.decoder(x=context, context=z, mask=context_mask)
+        trg = trg.long()
+        trg_mask = trg != self.pad_index
+        trg = self.token_embedding(trg) * math.sqrt(self.hidden_dim)
+        trg = self.token_pos_encoder(trg)
+        out = self.decoder(x=trg, context=src, mask=trg_mask)
         logits = self.head(out)  # [B, Sy, T]
         logits = logits.permute(0, 2, 1)  # [B, T, Sy]
         return logits
-- 
cgit v1.2.3-70-g09d2
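
A minimal usage sketch of the renamed decode(src, trg) path follows. It is
not the project's actual module: the vocabulary size, pad index, hidden
width, and the stock nn.TransformerDecoder stand-in (with the positional
encoder left out) are assumptions; only the src/trg naming, the padding
mask, the sqrt(hidden_dim) embedding scale, and the final permute mirror
the patch.

import math

import torch
from torch import Tensor, nn

VOCAB_SIZE = 1006   # assumed word-piece vocabulary size
PAD_INDEX = 3       # assumed padding token index
HIDDEN_DIM = 128    # assumed embedding/decoder width


class DecodeSketch(nn.Module):
    """Stand-in for ConvTransformer.decode: src is the latent image
    embedding, trg holds target word-piece token indices."""

    def __init__(self) -> None:
        super().__init__()
        self.token_embedding = nn.Embedding(VOCAB_SIZE, HIDDEN_DIM)
        layer = nn.TransformerDecoderLayer(
            d_model=HIDDEN_DIM, nhead=4, batch_first=True
        )
        self.decoder = nn.TransformerDecoder(layer, num_layers=2)
        self.head = nn.Linear(HIDDEN_DIM, VOCAB_SIZE)

    def decode(self, src: Tensor, trg: Tensor) -> Tensor:
        trg = trg.long()
        trg_mask = trg != PAD_INDEX             # True where trg holds a real token
        trg = self.token_embedding(trg) * math.sqrt(HIDDEN_DIM)
        # The project applies a positional encoder here; omitted in this sketch.
        out = self.decoder(
            tgt=trg,
            memory=src,                         # what the patch passes as context=src (was z)
            tgt_key_padding_mask=~trg_mask,     # torch expects True = ignore, so invert
        )
        logits = self.head(out)                 # [B, Sy, T]
        return logits.permute(0, 2, 1)          # [B, T, Sy]


if __name__ == "__main__":
    model = DecodeSketch()
    src = torch.randn(2, 36, HIDDEN_DIM)            # latent image embedding, (B, Sx, E)
    trg = torch.randint(0, VOCAB_SIZE, (2, 10))     # target token indices, (B, Sy)
    print(model.decode(src, trg).shape)             # torch.Size([2, 1006, 10])

Keeping the encoder output as the cross-attention memory and the target
tokens as the decoder input is the usual seq2seq layout, so the new
src/trg names line up with the common convention.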