From cf2a827db5798a245dd5207685251675d311dbec Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Fri, 30 Sep 2022 01:17:57 +0200 Subject: Fix api bug in model --- text_recognizer/models/transformer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 3c38ced..6048901 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -112,15 +112,15 @@ class LitTransformer(LitBase): bsz = x.shape[0] # Encode image(s) to latent vectors. - z = self.network.encode(x) + img_features = self.network.encode(x) # Create a placeholder matrix for storing outputs from the network indecies = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device) indecies[:, 0] = start_index for Sy in range(1, self.max_output_len): - context = indecies[:, :Sy] # (B, Sy) - logits = self.network.decode(z, context) # (B, C, Sy) + tokens = indecies[:, :Sy] # (B, Sy) + logits = self.network.decode(tokens, img_features) # (B, C, Sy) indecies_ = torch.argmax(logits, dim=1) # (B, Sy) indecies[:, Sy : Sy + 1] = indecies_[:, -1:] -- cgit v1.2.3-70-g09d2