From 87e7b7d1a4eb35df7cb4484f379c186efd981d6b Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 24 Oct 2021 00:57:24 +0200 Subject: Refator attention --- text_recognizer/networks/transformer/attention.py | 104 ++++++++++------------ 1 file changed, 48 insertions(+), 56 deletions(-) (limited to 'text_recognizer/networks/transformer') diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index 4e32065..e098b63 100644 --- a/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py @@ -46,57 +46,6 @@ class Attention(nn.Module): # Feedforward self.fc = nn.Linear(inner_dim, self.dim) - @staticmethod - def _apply_rotary_emb( - q: Tensor, k: Tensor, v: Tensor, rotary_pos_emb: Tensor - ) -> Tuple[Tensor, Tensor, Tensor]: - l = rotary_pos_emb.shape[-1] - (ql, qr), (kl, kr), (vl, vr) = map( - lambda t: (t[..., :l], t[..., l:]), (q, k, v) - ) - ql, kl, vl = map( - lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl) - ) - q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr))) - return q, k, v - - @staticmethod - def _compute_input_mask( - b: int, - n: int, - k: Tensor, - mask: Optional[Tensor], - context: Optional[Tensor], - context_mask: Optional[Tensor], - device: str, - ) -> Optional[Tensor]: - if any(x is not None for x in (mask, context_mask)): - q_mask = ( - mask if mask is not None else torch.ones((b, n), device=device).bool() - ) - k_mask = q_mask if context is None else context_mask - k_mask = ( - torch.ones((b, k.shape[-2]), device=device).bool() - if k_mask is None - else k_mask - ) - q_mask = rearrange(q_mask, "b i -> b () i ()") - k_mask = rearrange(k_mask, "b j -> b () () j") - return q_mask * k_mask - return - - @staticmethod - def _apply_causal_mask( - energy: Tensor, mask: Tensor, mask_value: Tensor, device: str - ) -> Tensor: - i, j = energy.shape[-2:] - r = torch.arange(i, device=device) - mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j") - mask = F.pad(mask, (j - i, 0), value=False) - energy.masked_fill_(mask, mask_value) - del mask - return energy - def forward( self, x: Tensor, @@ -114,14 +63,12 @@ class Attention(nn.Module): lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (q, k, v) ) q, k, v = ( - self._apply_rotary_emb(q, k, v, rotary_pos_emb) + apply_rotary_emb(q, k, v, rotary_pos_emb) if rotary_pos_emb is not None and context is None else (q, k, v,) ) - input_mask = self._compute_input_mask( - b, n, k, mask, context, context_mask, device - ) + input_mask = compute_input_mask(b, n, k, mask, context, context_mask, device) # Compute the attention energy = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale @@ -133,7 +80,7 @@ class Attention(nn.Module): del input_mask if self.causal: - energy = self._apply_causal_mask(energy, mask, mask_value, device) + energy = apply_causal_mask(energy, mask, mask_value, device) attn = F.softmax(energy, dim=-1) attn = self.dropout(attn) @@ -141,3 +88,48 @@ class Attention(nn.Module): out = rearrange(out, "b h n d -> b n (h d)") out = self.fc(out) return out, attn + + +def apply_rotary_emb( + q: Tensor, k: Tensor, v: Tensor, rotary_pos_emb: Tensor +) -> Tuple[Tensor, Tensor, Tensor]: + l = rotary_pos_emb.shape[-1] + (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v)) + ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl)) + q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr))) + return q, k, v + + +def compute_input_mask( + b: int, + n: int, + k: Tensor, + mask: Optional[Tensor], + context: Optional[Tensor], + context_mask: Optional[Tensor], + device: str, +) -> Optional[Tensor]: + if any(x is not None for x in (mask, context_mask)): + q_mask = mask if mask is not None else torch.ones((b, n), device=device).bool() + k_mask = q_mask if context is None else context_mask + k_mask = ( + torch.ones((b, k.shape[-2]), device=device).bool() + if k_mask is None + else k_mask + ) + q_mask = rearrange(q_mask, "b i -> b () i ()") + k_mask = rearrange(k_mask, "b j -> b () () j") + return q_mask * k_mask + return + + +def apply_causal_mask( + energy: Tensor, mask: Tensor, mask_value: Tensor, device: str +) -> Tensor: + i, j = energy.shape[-2:] + r = torch.arange(i, device=device) + mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j") + mask = F.pad(mask, (j - i, 0), value=False) + energy.masked_fill_(mask, mask_value) + del mask + return energy -- cgit v1.2.3-70-g09d2