summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/conv_transformer.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-05 19:25:59 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-05 19:26:39 +0100
commite5e776cb7ce3486d1a9e16f6ae328f55fd20f02b (patch)
tree61ff5abc85015a720793fe724d7b65c4ca82764e /text_recognizer/networks/conv_transformer.py
parentea525029b8b0355c656280e491796b4821c491a4 (diff)
Rename mask to input_mask
Rename mask to input_mask Rename mask to input_mask
Diffstat (limited to 'text_recognizer/networks/conv_transformer.py')
-rw-r--r--text_recognizer/networks/conv_transformer.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index 0c838d8..59ce814 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -118,7 +118,7 @@ class ConvTransformer(nn.Module):
if self.token_pos_embedding is not None
else trg
)
- out = self.decoder(x=trg, context=src, mask=trg_mask)
+ out = self.decoder(x=trg, context=src, input_mask=trg_mask)
logits = self.head(out) # [B, Sy, T]
logits = logits.permute(0, 2, 1) # [B, T, Sy]
return logits