From c64c85c36e67a2bae07cac1adeef70e82e69225c Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 3 Nov 2021 22:12:25 +0100 Subject: Update output shape from attn module --- text_recognizer/networks/transformer/attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'text_recognizer/networks') diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index b73fec0..54ef5e2 100644 --- a/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py @@ -53,7 +53,7 @@ class Attention(nn.Module): context: Optional[Tensor] = None, mask: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: + ) -> Tensor: """Computes the attention.""" b, n, _, device = *x.shape, x.device @@ -81,7 +81,7 @@ class Attention(nn.Module): out = einsum("b h i j, b h j d -> b h i d", attn, v) out = rearrange(out, "b h n d -> b n (h d)") out = self.fc(out) - return out, attn + return out def _apply_rotary_emb( -- cgit v1.2.3-70-g09d2