Diffstat (limited to 'text_recognizer/network/transformer/attend.py')
 text_recognizer/network/transformer/attend.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/text_recognizer/network/transformer/attend.py b/text_recognizer/network/transformer/attend.py
index 4e643fb..d2c17b1 100644
--- a/text_recognizer/network/transformer/attend.py
+++ b/text_recognizer/network/transformer/attend.py
@@ -32,7 +32,7 @@ class Attend(nn.Module):
         out = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
         return out
 
-    def atten(
+    def attn(
         self,
         q: Tensor,
         k: Tensor,
@@ -66,7 +66,7 @@ class Attend(nn.Module):
         if self.use_flash:
             return self.flash_attn(q, k, v, causal)
         else:
-            return self.atten(q, k, v, causal, mask)
+            return self.attn(q, k, v, causal, mask)
 
     def apply_input_mask(
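
For context, the diff only renames atten to attn; the method bodies are not part of the change. Below is a minimal, self-contained sketch of the Attend module the two hunks imply. Only the method names, their signatures, and the forward dispatch are visible in the diff; the use_flash constructor flag and the body of attn (a standard masked scaled dot-product attention) are assumptions, not the repository's actual implementation.

import torch
import torch.nn.functional as F
from torch import Tensor, nn


class Attend(nn.Module):
    """Sketch: dispatch between fused and manual attention."""

    def __init__(self, use_flash: bool = True) -> None:
        super().__init__()
        self.use_flash = use_flash  # assumed flag; name inferred from the second hunk

    def flash_attn(self, q: Tensor, k: Tensor, v: Tensor, causal: bool) -> Tensor:
        # Fused kernel path, as shown verbatim in the first hunk.
        return F.scaled_dot_product_attention(q, k, v, is_causal=causal)

    def attn(
        self,
        q: Tensor,
        k: Tensor,
        v: Tensor,
        causal: bool,
        mask: Tensor | None = None,
    ) -> Tensor:
        # Assumed manual fallback: masked scaled dot-product attention.
        scale = q.shape[-1] ** -0.5
        sim = torch.einsum("... i d, ... j d -> ... i j", q, k) * scale
        neg_inf = torch.finfo(sim.dtype).min
        if mask is not None:
            # mask is assumed True where attention is allowed.
            sim = sim.masked_fill(~mask, neg_inf)
        if causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones(
                i, j, dtype=torch.bool, device=sim.device
            ).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, neg_inf)
        return torch.einsum("... i j, ... j d -> ... i d", sim.softmax(dim=-1), v)

    def forward(
        self,
        q: Tensor,
        k: Tensor,
        v: Tensor,
        causal: bool = False,
        mask: Tensor | None = None,
    ) -> Tensor:
        # Dispatch matching the second hunk: fused kernel when available,
        # otherwise the manual fallback that honours an explicit mask.
        if self.use_flash:
            return self.flash_attn(q, k, v, causal)
        else:
            return self.attn(q, k, v, causal, mask)

The rename itself is cosmetic, but it matters for the dispatch in forward: the call site and the method definition must agree, which is why both hunks change together.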