author     Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-10-28 21:19:32 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-10-28 21:19:32 +0200
commit     bcab7d59947b86fccae809e8193cff77eeb9a81d (patch)
tree       91993347afd1f3d8ec1316e9e598bbdb2be4a0f0  /text_recognizer/networks/transformer/attention.py
parent     80dcd02c52b11e944967e5e80a6562c4a8b3ed2e (diff)
Refactor attention module
Diffstat (limited to 'text_recognizer/networks/transformer/attention.py')
-rw-r--r--  text_recognizer/networks/transformer/attention.py  34
1 file changed, 19 insertions(+), 15 deletions(-)
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 9b33944..d687056 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -67,17 +67,11 @@ class Attention(nn.Module):
             else (q, k, v,)
         )
 
-        input_mask = compute_input_mask(b, n, k, mask, context, context_mask, device)
-
-        # Compute the attention
         energy = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
         mask_value = -torch.finfo(energy.dtype).max
-
-        # Apply input mask
-        if input_mask is not None:
-            energy = energy.masked_fill_(~input_mask, mask_value)
-            del input_mask
-
+        energy = apply_input_mask(
+            b, n, k, energy, mask, context, context_mask, mask_value, device
+        )
         if self.causal:
             energy = apply_causal_mask(energy, mask, mask_value, device)
@@ -92,22 +86,28 @@ class Attention(nn.Module):
 def apply_rotary_emb(
     q: Tensor, k: Tensor, v: Tensor, rotary_pos_emb: Tensor
 ) -> Tuple[Tensor, Tensor, Tensor]:
-    l = rotary_pos_emb.shape[-1]
-    (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
+    """Applies rotary embedding."""
+    emb_len = rotary_pos_emb.shape[-1]
+    (ql, qr), (kl, kr), (vl, vr) = map(
+        lambda t: (t[..., :emb_len], t[..., emb_len:]), (q, k, v)
+    )
     ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl))
     q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))
     return q, k, v
 
 
-def compute_input_mask(
+def apply_input_mask(
     b: int,
     n: int,
     k: Tensor,
+    energy: Tensor,
     mask: Optional[Tensor],
     context: Optional[Tensor],
     context_mask: Optional[Tensor],
+    mask_value: Tensor,
     device: str,
-) -> Optional[Tensor]:
+) -> Tensor:
+    """Applies an input mask."""
     if any(x is not None for x in (mask, context_mask)):
         q_mask = mask if mask is not None else torch.ones((b, n), device=device).bool()
         k_mask = q_mask if context is None else context_mask
@@ -118,13 +118,17 @@ def compute_input_mask(
         )
         q_mask = rearrange(q_mask, "b i -> b () i ()")
         k_mask = rearrange(k_mask, "b j -> b () () j")
-        return q_mask * k_mask
-    return
+        input_mask = q_mask * k_mask
+
+        energy = energy.masked_fill_(~input_mask, mask_value)
+        del input_mask
+    return energy
 
 
 def apply_causal_mask(
     energy: Tensor, mask: Tensor, mask_value: Tensor, device: str
 ) -> Tensor:
+    """Applies a causal mask to the energy tensor."""
     i, j = energy.shape[-2:]
     r = torch.arange(i, device=device)
     mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j")
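
For context, here is a minimal self-contained sketch of the refactored masking flow this commit introduces. It assumes q and k shaped (batch, heads, seq, dim_head), a dim_head ** -0.5 scale, and an all-ones fallback for k_mask; the fallback and the __main__ harness are not visible in the hunks above and are illustrative assumptions, not the exact file contents.

# Sketch of the refactored input-mask flow (assumptions noted in comments).
from typing import Optional

import torch
from einops import rearrange
from torch import Tensor, einsum


def apply_input_mask(
    b: int,
    n: int,
    k: Tensor,
    energy: Tensor,
    mask: Optional[Tensor],
    context: Optional[Tensor],
    context_mask: Optional[Tensor],
    mask_value: float,
    device: str,
) -> Tensor:
    """Masks out attention energies for padded query/key positions."""
    if any(x is not None for x in (mask, context_mask)):
        # Default to "attend everywhere" when a mask is missing (assumed fallback).
        q_mask = mask if mask is not None else torch.ones((b, n), device=device).bool()
        k_mask = q_mask if context is None else context_mask
        k_mask = (
            torch.ones((b, k.shape[-2]), device=device).bool()
            if k_mask is None
            else k_mask
        )
        # Broadcast (b, 1, i, 1) x (b, 1, 1, j) -> (b, 1, i, j).
        q_mask = rearrange(q_mask, "b i -> b () i ()")
        k_mask = rearrange(k_mask, "b j -> b () () j")
        input_mask = q_mask * k_mask
        energy = energy.masked_fill_(~input_mask, mask_value)
        del input_mask
    return energy


if __name__ == "__main__":
    b, h, n, d = 2, 4, 8, 16
    q, k = torch.randn(b, h, n, d), torch.randn(b, h, n, d)
    mask = torch.ones(b, n).bool()
    mask[:, -2:] = False  # treat the last two positions as padding

    energy = einsum("b h i d, b h j d -> b h i j", q, k) * d ** -0.5
    mask_value = -torch.finfo(energy.dtype).max
    energy = apply_input_mask(b, n, k, energy, mask, None, None, mask_value, "cpu")
    # Padded key positions receive (numerically) zero attention weight.
    print(energy.softmax(dim=-1)[0, 0, 0])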