diff options
Diffstat (limited to 'text_recognizer/networks/transformer')
-rw-r--r-- | text_recognizer/networks/transformer/attention.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index b73fec0..54ef5e2 100644 --- a/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py @@ -53,7 +53,7 @@ class Attention(nn.Module): context: Optional[Tensor] = None, mask: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: + ) -> Tensor: """Computes the attention.""" b, n, _, device = *x.shape, x.device @@ -81,7 +81,7 @@ class Attention(nn.Module): out = einsum("b h i j, b h j d -> b h i d", attn, v) out = rearrange(out, "b h n d -> b n (h d)") out = self.fc(out) - return out, attn + return out def _apply_rotary_emb( |