diff options
Diffstat (limited to 'text_recognizer/networks/transformer')
-rw-r--r-- | text_recognizer/networks/transformer/attention.py | 26 |
1 files changed, 14 insertions, 12 deletions
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index 34b6101..4e32065 100644 --- a/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py @@ -48,14 +48,17 @@ class Attention(nn.Module): @staticmethod def _apply_rotary_emb( - q: Tensor, k: Tensor, rotary_pos_emb: Tensor - ) -> Tuple[Tensor, Tensor]: + q: Tensor, k: Tensor, v: Tensor, rotary_pos_emb: Tensor + ) -> Tuple[Tensor, Tensor, Tensor]: l = rotary_pos_emb.shape[-1] - (ql, qr), (kl, kr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k)) - ql, kl = apply_rotary_pos_emb(ql, kl, rotary_pos_emb) - q = torch.cat((ql, qr), dim=-1) - k = torch.cat((kl, kr), dim=-1) - return q, k + (ql, qr), (kl, kr), (vl, vr) = map( + lambda t: (t[..., :l], t[..., l:]), (q, k, v) + ) + ql, kl, vl = map( + lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl) + ) + q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr))) + return q, k, v @staticmethod def _compute_input_mask( @@ -110,11 +113,10 @@ class Attention(nn.Module): q, k, v = map( lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (q, k, v) ) - q, k = ( - self._apply_rotary_emb(q, k, rotary_pos_emb) - if rotary_pos_emb is not None - else q, - k, + q, k, v = ( + self._apply_rotary_emb(q, k, v, rotary_pos_emb) + if rotary_pos_emb is not None and context is None + else (q, k, v,) ) input_mask = self._compute_input_mask( |