Diffstat (limited to 'text_recognizer/networks/transformer/attention.py'):
 text_recognizer/networks/transformer/attention.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index b73fec0..54ef5e2 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -53,7 +53,7 @@ class Attention(nn.Module):
         context: Optional[Tensor] = None,
         mask: Optional[Tensor] = None,
         context_mask: Optional[Tensor] = None,
-    ) -> Tuple[Tensor, Tensor]:
+    ) -> Tensor:
         """Computes the attention."""
         b, n, _, device = *x.shape, x.device
@@ -81,7 +81,7 @@ class Attention(nn.Module):
         out = einsum("b h i j, b h j d -> b h i d", attn, v)
         out = rearrange(out, "b h n d -> b n (h d)")
         out = self.fc(out)
-        return out, attn
+        return out
 
     def _apply_rotary_emb(
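
This commit changes the module's public contract: `Attention.forward` now returns only the output tensor instead of an `(out, attn)` tuple, so callers no longer unpack the attention weights. A minimal usage sketch of the new calling convention, assuming a hypothetical constructor signature (`dim` and `num_heads` are illustrative; they are not shown in this diff):

```python
import torch

from text_recognizer.networks.transformer.attention import Attention

# Hypothetical configuration; the constructor arguments are assumed,
# not confirmed by this diff.
attention = Attention(dim=256, num_heads=8)

x = torch.randn(1, 16, 256)  # (batch, sequence length, embedding dim)

# Before this commit: out, attn = attention(x)
# After this commit: only the projected output tensor is returned.
out = attention(x)
assert out.shape == (1, 16, 256)
```

Presumably the attention weights were unused by callers, so returning a single tensor lets the module compose directly in a transformer block without discarding the second tuple element at every call site.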