summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-09-30 23:08:46 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-09-30 23:08:46 +0200
commitc899c05e801b5c07159353434390e10b8625fe06 (patch)
tree680be27a663ec4efa18c48934fcb62eaf1491d8d
parent9e98c19d9e218b465a7d03c1b22c1d480f065741 (diff)
Major bug fix in attention layer
-rw-r--r--text_recognizer/networks/transformer/attention.py19
1 files changed, 12 insertions, 7 deletions
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 37ce29e..34b6101 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -37,11 +37,10 @@ class Attention(nn.Module):
self.scale = self.dim ** -0.5
inner_dim = self.dim * self.dim_head
- # Attnetion
- self.qkv_fn = nn.Sequential(
- nn.Linear(self.dim, 3 * inner_dim, bias=False),
- Rearrange("b n (qkv h d) -> qkv b h n d", qkv=3, h=self.num_heads),
- )
+ self.query = nn.Linear(self.dim, inner_dim, bias=False)
+ self.key = nn.Linear(self.dim, inner_dim, bias=False)
+ self.value = nn.Linear(self.dim, inner_dim, bias=False)
+
self.dropout = nn.Dropout(p=self.dropout_rate)
# Feedforward
@@ -72,7 +71,7 @@ class Attention(nn.Module):
q_mask = (
mask if mask is not None else torch.ones((b, n), device=device).bool()
)
- k_mask = q_mask if context is not None else context_mask
+ k_mask = q_mask if context is None else context_mask
k_mask = (
torch.ones((b, k.shape[-2]), device=device).bool()
if k_mask is None
@@ -104,7 +103,13 @@ class Attention(nn.Module):
rotary_pos_emb: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
b, n, _, device = *x.shape, x.device
- q, k, v = self.qkv_fn(x)
+
+ q = self.query(x)
+ k = self.key(context) if context is not None else self.key(x)
+ v = self.value(context) if context is not None else self.value(x)
+ q, k, v = map(
+ lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (q, k, v)
+ )
q, k = (
self._apply_rotary_emb(q, k, rotary_pos_emb)
if rotary_pos_emb is not None