summaryrefslogtreecommitdiff
path: root/text_recognizer/networks
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-03 22:12:25 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-03 22:12:25 +0100
commitc64c85c36e67a2bae07cac1adeef70e82e69225c (patch)
tree4c1def0c00475d94cb1aeb8ae571b6a89f6c2531 /text_recognizer/networks
parent113153671b5ab9ff613a03dbbfcf4266e269bd9f (diff)
Update output shape from attn module
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--text_recognizer/networks/transformer/attention.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index b73fec0..54ef5e2 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -53,7 +53,7 @@ class Attention(nn.Module):
context: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
context_mask: Optional[Tensor] = None,
- ) -> Tuple[Tensor, Tensor]:
+ ) -> Tensor:
"""Computes the attention."""
b, n, _, device = *x.shape, x.device
@@ -81,7 +81,7 @@ class Attention(nn.Module):
out = einsum("b h i j, b h j d -> b h i d", attn, v)
out = rearrange(out, "b h n d -> b n (h d)")
out = self.fc(out)
- return out, attn
+ return out
def _apply_rotary_emb(