From a6d6d50ba72556fcd3ca736e8c5a9bc59516cd32 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Thu, 30 Sep 2021 23:06:36 +0200
Subject: Rename context to trg in transformer

---
 text_recognizer/networks/conv_transformer.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

(limited to 'text_recognizer')

diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index ba43b31..714792e 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -99,12 +99,12 @@ class ConvTransformer(nn.Module):
         z = z.permute(0, 2, 1)
         return z
 
-    def decode(self, z: Tensor, context: Tensor) -> Tensor:
+    def decode(self, src: Tensor, trg: Tensor) -> Tensor:
         """Decodes latent images embedding into word pieces.
 
         Args:
-            z (Tensor): Latent images embedding.
-            context (Tensor): Word embeddings.
+            src (Tensor): Latent images embedding.
+            trg (Tensor): Word embeddings.
 
         Shapes:
             - z: :math: `(B, Sx, E)`
@@ -116,11 +116,11 @@ class ConvTransformer(nn.Module):
         Returns:
             Tensor: Sequence of word piece embeddings.
         """
-        context = context.long()
-        context_mask = context != self.pad_index
-        context = self.token_embedding(context) * math.sqrt(self.hidden_dim)
-        context = self.token_pos_encoder(context)
-        out = self.decoder(x=context, context=z, mask=context_mask)
+        trg = trg.long()
+        trg_mask = trg != self.pad_index
+        trg = self.token_embedding(trg) * math.sqrt(self.hidden_dim)
+        trg = self.token_pos_encoder(trg)
+        out = self.decoder(x=trg, context=src, mask=trg_mask)
         logits = self.head(out)  # [B, Sy, T]
         logits = logits.permute(0, 2, 1)  # [B, T, Sy]
         return logits
--
cgit v1.2.3-70-g09d2
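
For context, a minimal usage sketch of the renamed `decode(src, trg)` signature. This is an assumption-laden illustration, not part of the commit: the instance name `network`, the input shapes, and the vocabulary size are made up, and it assumes the class also exposes the `encode` method whose tail appears in the first hunk.

```python
import torch

# Assumed: `network` is an already-constructed ConvTransformer instance.
# Shapes and vocabulary size below are illustrative only.
images = torch.randn(4, 1, 56, 1024)          # (B, C, H, W), assumed input size
targets = torch.randint(0, 58, (4, 32)).long()  # (B, Sy) token indices, padded with network.pad_index

src = network.encode(images)                  # latent image embedding, (B, Sx, E)
logits = network.decode(src=src, trg=targets)  # (B, T, Sy), per the in-code comments
```

With the rename, the decoder's own keyword argument `context` now unambiguously refers to the encoder output (`src`), while `trg` carries the target token sequence, which is what the argument-name collision previously obscured.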