diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-09-30 23:08:46 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-09-30 23:08:46 +0200 |
commit | c899c05e801b5c07159353434390e10b8625fe06 (patch) | |
tree | 680be27a663ec4efa18c48934fcb62eaf1491d8d | |
parent | 9e98c19d9e218b465a7d03c1b22c1d480f065741 (diff) |
Major bug fix in attention layer
-rw-r--r-- | text_recognizer/networks/transformer/attention.py | 19 |
1 files changed, 12 insertions, 7 deletions
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index 37ce29e..34b6101 100644 --- a/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py @@ -37,11 +37,10 @@ class Attention(nn.Module): self.scale = self.dim ** -0.5 inner_dim = self.dim * self.dim_head - # Attnetion - self.qkv_fn = nn.Sequential( - nn.Linear(self.dim, 3 * inner_dim, bias=False), - Rearrange("b n (qkv h d) -> qkv b h n d", qkv=3, h=self.num_heads), - ) + self.query = nn.Linear(self.dim, inner_dim, bias=False) + self.key = nn.Linear(self.dim, inner_dim, bias=False) + self.value = nn.Linear(self.dim, inner_dim, bias=False) + self.dropout = nn.Dropout(p=self.dropout_rate) # Feedforward @@ -72,7 +71,7 @@ class Attention(nn.Module): q_mask = ( mask if mask is not None else torch.ones((b, n), device=device).bool() ) - k_mask = q_mask if context is not None else context_mask + k_mask = q_mask if context is None else context_mask k_mask = ( torch.ones((b, k.shape[-2]), device=device).bool() if k_mask is None @@ -104,7 +103,13 @@ class Attention(nn.Module): rotary_pos_emb: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: b, n, _, device = *x.shape, x.device - q, k, v = self.qkv_fn(x) + + q = self.query(x) + k = self.key(context) if context is not None else self.key(x) + v = self.value(context) if context is not None else self.value(x) + q, k, v = map( + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (q, k, v) + ) q, k = ( self._apply_rotary_emb(q, k, rotary_pos_emb) if rotary_pos_emb is not None |