summary | refs | log | tree | commit | diff
path: root/text_recognizer/models
diff options
context:
space:
mode:
author: Gustaf Rydholm <gustaf.rydholm@gmail.com> 2022-09-30 01:17:57 +0200
committer: Gustaf Rydholm <gustaf.rydholm@gmail.com> 2022-09-30 01:17:57 +0200
commit: cf2a827db5798a245dd5207685251675d311dbec (patch)
tree: 13d0caaa6bdbc57ec13740630362c6ad9fe9d8c4 /text_recognizer/models
parent: c614c472707910658b86bb28b9f02062e6982999 (diff)
Fix api bug in model
Diffstat (limited to 'text_recognizer/models')
-rw-r--r-- text_recognizer/models/transformer.py 6
1 file changed, 3 insertions, 3 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 3c38ced..6048901 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -112,15 +112,15 @@ class LitTransformer(LitBase):
bsz = x.shape[0]
# Encode image(s) to latent vectors.
- z = self.network.encode(x)
+ img_features = self.network.encode(x)
# Create a placeholder matrix for storing outputs from the network
indecies = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
indecies[:, 0] = start_index
for Sy in range(1, self.max_output_len):
- context = indecies[:, :Sy] # (B, Sy)
- logits = self.network.decode(z, context) # (B, C, Sy)
+ tokens = indecies[:, :Sy] # (B, Sy)
+ logits = self.network.decode(tokens, img_features) # (B, C, Sy)
indecies_ = torch.argmax(logits, dim=1) # (B, Sy)
indecies[:, Sy : Sy + 1] = indecies_[:, -1:]