diff options
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/models/transformer.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 3c38ced..6048901 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -112,15 +112,15 @@ class LitTransformer(LitBase):
         bsz = x.shape[0]
 
         # Encode image(s) to latent vectors.
-        z = self.network.encode(x)
+        img_features = self.network.encode(x)
 
         # Create a placeholder matrix for storing outputs from the network
         indecies = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
         indecies[:, 0] = start_index
 
         for Sy in range(1, self.max_output_len):
-            context = indecies[:, :Sy]  # (B, Sy)
-            logits = self.network.decode(z, context)  # (B, C, Sy)
+            tokens = indecies[:, :Sy]  # (B, Sy)
+            logits = self.network.decode(tokens, img_features)  # (B, C, Sy)
             indecies_ = torch.argmax(logits, dim=1)  # (B, Sy)
             indecies[:, Sy : Sy + 1] = indecies_[:, -1:]
 