diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-30 01:17:57 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-30 01:17:57 +0200 |
commit | cf2a827db5798a245dd5207685251675d311dbec (patch) | |
tree | 13d0caaa6bdbc57ec13740630362c6ad9fe9d8c4 | |
parent | c614c472707910658b86bb28b9f02062e6982999 (diff) |
Fix api bug in model
-rw-r--r-- | text_recognizer/models/transformer.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 3c38ced..6048901 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -112,15 +112,15 @@ class LitTransformer(LitBase): bsz = x.shape[0] # Encode image(s) to latent vectors. - z = self.network.encode(x) + img_features = self.network.encode(x) # Create a placeholder matrix for storing outputs from the network indecies = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device) indecies[:, 0] = start_index for Sy in range(1, self.max_output_len): - context = indecies[:, :Sy] # (B, Sy) - logits = self.network.decode(z, context) # (B, C, Sy) + tokens = indecies[:, :Sy] # (B, Sy) + logits = self.network.decode(tokens, img_features) # (B, C, Sy) indecies_ = torch.argmax(logits, dim=1) # (B, Sy) indecies[:, Sy : Sy + 1] = indecies_[:, -1:] |