diff options
-rw-r--r-- | text_recognizer/networks/transformer/attention.py | 34 |
1 files changed, 19 insertions, 15 deletions
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index 9b33944..d687056 100644 --- a/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py @@ -67,17 +67,11 @@ class Attention(nn.Module): else (q, k, v,) ) - input_mask = compute_input_mask(b, n, k, mask, context, context_mask, device) - - # Compute the attention energy = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale mask_value = -torch.finfo(energy.dtype).max - - # Apply input mask - if input_mask is not None: - energy = energy.masked_fill_(~input_mask, mask_value) - del input_mask - + energy = apply_input_mask( + b, n, k, energy, mask, context, context_mask, mask_value, device + ) if self.causal: energy = apply_causal_mask(energy, mask, mask_value, device) @@ -92,22 +86,28 @@ class Attention(nn.Module): def apply_rotary_emb( q: Tensor, k: Tensor, v: Tensor, rotary_pos_emb: Tensor ) -> Tuple[Tensor, Tensor, Tensor]: - l = rotary_pos_emb.shape[-1] - (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v)) + """Applies rotary embedding.""" + emb_len = rotary_pos_emb.shape[-1] + (ql, qr), (kl, kr), (vl, vr) = map( + lambda t: (t[..., :emb_len], t[..., emb_len:]), (q, k, v) + ) ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl)) q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr))) return q, k, v -def compute_input_mask( +def apply_input_mask( b: int, n: int, k: Tensor, + energy: Tensor, mask: Optional[Tensor], context: Optional[Tensor], context_mask: Optional[Tensor], + mask_value: Tensor, device: str, -) -> Optional[Tensor]: +) -> Tensor: + """Applies an input mask.""" if any(x is not None for x in (mask, context_mask)): q_mask = mask if mask is not None else torch.ones((b, n), device=device).bool() k_mask = q_mask if context is None else context_mask @@ -118,13 +118,17 @@ def compute_input_mask( ) q_mask = rearrange(q_mask, "b i -> b () i ()") k_mask = rearrange(k_mask, "b j -> b () () j") - return q_mask * k_mask - return + input_mask = q_mask * k_mask + + energy = energy.masked_fill_(~input_mask, mask_value) + del input_mask + return energy def apply_causal_mask( energy: Tensor, mask: Tensor, mask_value: Tensor, device: str ) -> Tensor: + """Applies a causal mask to the energy tensor.""" i, j = energy.shape[-2:] r = torch.arange(i, device=device) mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j") |