diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-03 22:12:25 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-03 22:12:25 +0100 |
commit | c64c85c36e67a2bae07cac1adeef70e82e69225c (patch) | |
tree | 4c1def0c00475d94cb1aeb8ae571b6a89f6c2531 /text_recognizer/networks | |
parent | 113153671b5ab9ff613a03dbbfcf4266e269bd9f (diff) |
Update output shape from attn module
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/transformer/attention.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index b73fec0..54ef5e2 100644 --- a/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py @@ -53,7 +53,7 @@ class Attention(nn.Module): context: Optional[Tensor] = None, mask: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: + ) -> Tensor: """Computes the attention.""" b, n, _, device = *x.shape, x.device @@ -81,7 +81,7 @@ class Attention(nn.Module): out = einsum("b h i j, b h j d -> b h i d", attn, v) out = rearrange(out, "b h n d -> b n (h d)") out = self.fc(out) - return out, attn + return out def _apply_rotary_emb( |