diff options
Diffstat (limited to 'text_recognizer/networks/transformer')
| -rw-r--r-- | text_recognizer/networks/transformer/attention.py | 34 | 
1 files changed, 19 insertions, 15 deletions
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index 9b33944..d687056 100644 --- a/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py @@ -67,17 +67,11 @@ class Attention(nn.Module):              else (q, k, v,)          ) -        input_mask = compute_input_mask(b, n, k, mask, context, context_mask, device) - -        # Compute the attention          energy = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale          mask_value = -torch.finfo(energy.dtype).max - -        # Apply input mask -        if input_mask is not None: -            energy = energy.masked_fill_(~input_mask, mask_value) -            del input_mask - +        energy = apply_input_mask( +            b, n, k, energy, mask, context, context_mask, mask_value, device +        )          if self.causal:              energy = apply_causal_mask(energy, mask, mask_value, device) @@ -92,22 +86,28 @@ class Attention(nn.Module):  def apply_rotary_emb(      q: Tensor, k: Tensor, v: Tensor, rotary_pos_emb: Tensor  ) -> Tuple[Tensor, Tensor, Tensor]: -    l = rotary_pos_emb.shape[-1] -    (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v)) +    """Applies rotary embedding.""" +    emb_len = rotary_pos_emb.shape[-1] +    (ql, qr), (kl, kr), (vl, vr) = map( +        lambda t: (t[..., :emb_len], t[..., emb_len:]), (q, k, v) +    )      ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl))      q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))      return q, k, v -def compute_input_mask( +def apply_input_mask(      b: int,      n: int,      k: Tensor, +    energy: Tensor,      mask: Optional[Tensor],      context: Optional[Tensor],      context_mask: Optional[Tensor], +    mask_value: Tensor,      device: str, -) -> Optional[Tensor]: +) -> Tensor: +    """Applies an input mask."""      if any(x is not None for x in (mask, context_mask)):          q_mask = mask if mask is not None else torch.ones((b, n), device=device).bool()          k_mask = q_mask if context is None else context_mask @@ -118,13 +118,17 @@ def compute_input_mask(          )          q_mask = rearrange(q_mask, "b i -> b () i ()")          k_mask = rearrange(k_mask, "b j -> b () () j") -        return q_mask * k_mask -    return +        input_mask = q_mask * k_mask + +        energy = energy.masked_fill_(~input_mask, mask_value) +        del input_mask +    return energy  def apply_causal_mask(      energy: Tensor, mask: Tensor, mask_value: Tensor, device: str  ) -> Tensor: +    """Applies a causal mask to the energy tensor."""      i, j = energy.shape[-2:]      r = torch.arange(i, device=device)      mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j")  |