diff options
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/conv_transformer.py | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index ba43b31..714792e 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -99,12 +99,12 @@ class ConvTransformer(nn.Module): z = z.permute(0, 2, 1) return z - def decode(self, z: Tensor, context: Tensor) -> Tensor: + def decode(self, src: Tensor, trg: Tensor) -> Tensor: """Decodes latent images embedding into word pieces. Args: - z (Tensor): Latent images embedding. - context (Tensor): Word embeddings. + src (Tensor): Latent images embedding. + trg (Tensor): Word embeddings. Shapes: - z: :math: `(B, Sx, E)` @@ -116,11 +116,11 @@ class ConvTransformer(nn.Module): Returns: Tensor: Sequence of word piece embeddings. """ - context = context.long() - context_mask = context != self.pad_index - context = self.token_embedding(context) * math.sqrt(self.hidden_dim) - context = self.token_pos_encoder(context) - out = self.decoder(x=context, context=z, mask=context_mask) + trg = trg.long() + trg_mask = trg != self.pad_index + trg = self.token_embedding(trg) * math.sqrt(self.hidden_dim) + trg = self.token_pos_encoder(trg) + out = self.decoder(x=trg, context=src, mask=trg_mask) logits = self.head(out) # [B, Sy, T] logits = logits.permute(0, 2, 1) # [B, T, Sy] return logits |