author    Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-10-03 00:31:20 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-10-03 00:31:20 +0200
commit    aa11bd46bbb7237680ab2e513dfb429e27de2536 (patch)
tree      c0a2a05e4d97432b96ef0f75f2783baf6ec5b500 /text_recognizer/networks
parent    3a21c29e2eff4378c63717f8920ca3ccbfef013c (diff)
Fix rotary embedding
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--  text_recognizer/networks/transformer/attention.py | 26
1 file changed, 14 insertions(+), 12 deletions(-)
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 34b6101..4e32065 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -48,14 +48,17 @@ class Attention(nn.Module):

     @staticmethod
     def _apply_rotary_emb(
-        q: Tensor, k: Tensor, rotary_pos_emb: Tensor
-    ) -> Tuple[Tensor, Tensor]:
+        q: Tensor, k: Tensor, v: Tensor, rotary_pos_emb: Tensor
+    ) -> Tuple[Tensor, Tensor, Tensor]:
         l = rotary_pos_emb.shape[-1]
-        (ql, qr), (kl, kr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k))
-        ql, kl = apply_rotary_pos_emb(ql, kl, rotary_pos_emb)
-        q = torch.cat((ql, qr), dim=-1)
-        k = torch.cat((kl, kr), dim=-1)
-        return q, k
+        (ql, qr), (kl, kr), (vl, vr) = map(
+            lambda t: (t[..., :l], t[..., l:]), (q, k, v)
+        )
+        ql, kl, vl = map(
+            lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl)
+        )
+        q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))
+        return q, k, v

     @staticmethod
     def _compute_input_mask(
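The hunk above rewrites _apply_rotary_emb so that the rotary position embedding is applied to the values as well as to the queries and keys, still only over the first l feature dimensions, with the remaining dimensions passed through untouched. Below is a minimal, self-contained sketch of that behaviour; apply_rotary_pos_emb and rotate_half here follow the common rotate-half formulation and are illustrative assumptions, not the repository's exact implementations.

# Minimal sketch of the patched helper (illustration only, not the repo's exact code).
# apply_rotary_pos_emb / rotate_half use the common rotate-half formulation, which
# is an assumption here; tensor shapes are (batch, heads, seq, dim).
from typing import Tuple

import torch
from torch import Tensor


def rotate_half(x: Tensor) -> Tensor:
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(t: Tensor, freqs: Tensor) -> Tensor:
    # Standard rotary update: t * cos(freqs) + rotate_half(t) * sin(freqs).
    return t * freqs.cos() + rotate_half(t) * freqs.sin()


def apply_rotary_emb(
    q: Tensor, k: Tensor, v: Tensor, rotary_pos_emb: Tensor
) -> Tuple[Tensor, Tensor, Tensor]:
    l = rotary_pos_emb.shape[-1]
    # Rotate only the first l feature dimensions of q, k and v ...
    (ql, qr), (kl, kr), (vl, vr) = ((t[..., :l], t[..., l:]) for t in (q, k, v))
    ql, kl, vl = (apply_rotary_pos_emb(t, rotary_pos_emb) for t in (ql, kl, vl))
    # ... and concatenate the untouched remainder back on.
    q, k, v = (torch.cat(pair, dim=-1) for pair in ((ql, qr), (kl, kr), (vl, vr)))
    return q, k, v


q = torch.randn(1, 4, 10, 64)
k = torch.randn(1, 4, 10, 64)
v = torch.randn(1, 4, 10, 64)
freqs = torch.randn(10, 32)  # broadcasts over batch and heads
q, k, v = apply_rotary_emb(q, k, v, freqs)
print(q.shape, k.shape, v.shape)  # each remains torch.Size([1, 4, 10, 64])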
@@ -110,11 +113,10 @@ class Attention(nn.Module):
         q, k, v = map(
             lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (q, k, v)
         )
-        q, k = (
-            self._apply_rotary_emb(q, k, rotary_pos_emb)
-            if rotary_pos_emb is not None
-            else q,
-            k,
+        q, k, v = (
+            self._apply_rotary_emb(q, k, v, rotary_pos_emb)
+            if rotary_pos_emb is not None and context is None
+            else (q, k, v,)
         )

         input_mask = self._compute_input_mask(
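The second hunk fixes a precedence bug in the conditional assignment: in the old code the trailing ", k," made the right-hand side a plain two-element tuple, so whenever a rotary embedding was supplied, q was silently bound to the whole (q, k) pair returned by _apply_rotary_emb while k kept its old value. The patched version parenthesises the else branch and unpacks q, k and v from whichever tuple the conditional yields; the added "context is None" check also skips rotation for cross-attention. A toy illustration of the pitfall (placeholder names, not the repository's code):

def rotate(q, k):
    # Stand-in for _apply_rotary_emb: returns an updated (q, k) pair.
    return q + 100, k + 100

q, k = 1, 2
use_rotary = True

# Old pattern: the conditional covers only the first tuple element, so the
# returned pair lands in q_bad and k_bad keeps the stale value.
q_bad, k_bad = (rotate(q, k) if use_rotary else q, k)
print(q_bad, k_bad)  # (101, 102) 2

# Fixed pattern: parenthesise the else branch so both names are unpacked
# from whichever pair the conditional produces.
q_ok, k_ok = rotate(q, k) if use_rotary else (q, k)
print(q_ok, k_ok)  # 101 102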