From aa11bd46bbb7237680ab2e513dfb429e27de2536 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 3 Oct 2021 00:31:20 +0200 Subject: Fix rotary embedding --- text_recognizer/networks/transformer/attention.py | 26 ++++++++++++----------- 1 file changed, 14 insertions(+), 12 deletions(-) (limited to 'text_recognizer') diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index 34b6101..4e32065 100644 --- a/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py @@ -48,14 +48,17 @@ class Attention(nn.Module): @staticmethod def _apply_rotary_emb( - q: Tensor, k: Tensor, rotary_pos_emb: Tensor - ) -> Tuple[Tensor, Tensor]: + q: Tensor, k: Tensor, v: Tensor, rotary_pos_emb: Tensor + ) -> Tuple[Tensor, Tensor, Tensor]: l = rotary_pos_emb.shape[-1] - (ql, qr), (kl, kr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k)) - ql, kl = apply_rotary_pos_emb(ql, kl, rotary_pos_emb) - q = torch.cat((ql, qr), dim=-1) - k = torch.cat((kl, kr), dim=-1) - return q, k + (ql, qr), (kl, kr), (vl, vr) = map( + lambda t: (t[..., :l], t[..., l:]), (q, k, v) + ) + ql, kl, vl = map( + lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl) + ) + q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr))) + return q, k, v @staticmethod def _compute_input_mask( @@ -110,11 +113,10 @@ class Attention(nn.Module): q, k, v = map( lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (q, k, v) ) - q, k = ( - self._apply_rotary_emb(q, k, rotary_pos_emb) - if rotary_pos_emb is not None - else q, - k, + q, k, v = ( + self._apply_rotary_emb(q, k, v, rotary_pos_emb) + if rotary_pos_emb is not None and context is None + else (q, k, v,) ) input_mask = self._compute_input_mask( -- cgit v1.2.3-70-g09d2