author     Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-10-28 21:20:02 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-10-28 21:20:02 +0200
commit     01b11ead9470b40ca24e41dca59ac6a8b3f65186 (patch)
tree       8614b893b0293ecdcaff0689cd398ef26ef0b9eb
parent     edeb58de74a9fd0ef36e1f53a52370efa66107ce (diff)
Fix multihead local attention
-rw-r--r--  text_recognizer/networks/transformer/local_attention.py  44
1 file changed, 26 insertions, 18 deletions
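
For reference, a minimal standalone sketch of the look_around helper as it reads after this patch, windowing over dim 3 of a (batch, heads, num_windows, window_size, ...) tensor. The body mirrors the patched function with variables lightly renamed for readability; the toy shapes in the demo (1 batch, 2 heads, 2 windows of size 3) are assumptions for illustration only.

import torch
import torch.nn.functional as F
from torch import Tensor


def look_around(x: Tensor, backward: int, pad_value: int = -1, dim: int = 3) -> Tensor:
    """Concatenate each window with its `backward` preceding windows."""
    num_windows = x.shape[2]
    # Pad the window axis (dim 2) on the left with `backward` dummy windows.
    pad_dims = (len(x.shape) - dim) * (0, 0)
    x_pad = F.pad(x, (*pad_dims, backward, 0), value=pad_value)
    # Shifted views: offset 0 is the oldest predecessor, offset `backward`
    # is the window itself; concatenate them along the window_size axis.
    shifted = [x_pad[:, :, ind : ind + num_windows, ...] for ind in range(backward + 1)]
    return torch.cat(shifted, dim=dim)


if __name__ == "__main__":
    # Bucket indices for 1 batch, 2 heads, 2 windows of size 3 (toy numbers).
    b_n = torch.arange(2 * 6).reshape(1, 2, 2, 3)
    bq_k = look_around(b_n, backward=1)
    print(bq_k.shape)   # torch.Size([1, 2, 2, 6]) -- the key window doubles
    print(bq_k[0, 0])   # first window is padded with -1 on the left
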
diff --git a/text_recognizer/networks/transformer/local_attention.py b/text_recognizer/networks/transformer/local_attention.py
index 134089c..002069c 100644
--- a/text_recognizer/networks/transformer/local_attention.py
+++ b/text_recognizer/networks/transformer/local_attention.py
@@ -54,7 +54,7 @@ class LocalAttention(nn.Module):
rotary_pos_emb: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
"""Computes windowed attention."""
- b, n, _ = x.shape
+ b, n, d = x.shape
if not n % self.window_size:
RuntimeError(
f"Sequence length {n} must be divisable with window size {self.window_size}"
@@ -75,25 +75,30 @@ class LocalAttention(nn.Module):
num_windows = n // self.window_size
# Compute buckets
- b_n = torch.arange(n).type_as(q).reshape(1, num_windows, self.window_size)
+ b_n = (
+ torch.arange(self.num_heads * n)
+ .type_as(q)
+ .reshape(1, self.num_heads, num_windows, self.window_size)
+ )
bq, bk, bv = map(
- lambda t: t.reshape(b, num_windows, self.window_size, -1), (q, k, v)
+ lambda t: t.reshape(b, self.num_heads, num_windows, self.window_size, -1),
+ (q, k, v),
)
- bk = look_around(bk, backward=self.backward)
- bv = look_around(bv, backward=self.backward)
- bq_k = look_around(b_n, backward=self.backward)
+ bk = look_around(bk, backward=self.look_back)
+ bv = look_around(bv, backward=self.look_back)
+ bq_k = look_around(b_n, backward=self.look_back)
# Compute the attention.
- energy = einsum("b h i d, b h j d -> b h i j", bq, bk) * self.scale
+ energy = einsum("b h n i d, b h n j d -> b h n i j", bq, bk) * self.scale
mask_value = -torch.finfo(energy.dtype).max
# Causal mask.
- causal_mask = b_n[:, :, :, None] < bq_k[:, :, None, :]
+ causal_mask = b_n[:, :, :, :, None] < bq_k[:, :, :, None, :]
energy = energy.masked_fill_(causal_mask, mask_value)
del causal_mask
- bucket_mask = bq_k[:, :, None, :] == -1
+ bucket_mask = bq_k[:, :, :, None, :] == -1
energy.masked_fill_(bucket_mask, mask_value)
del bucket_mask
@@ -101,17 +106,18 @@ class LocalAttention(nn.Module):
b,
energy=energy,
mask=mask,
- backward=self.backward,
+ backward=self.look_back,
window_size=self.window_size,
num_windows=num_windows,
+ num_heads=self.num_heads,
mask_value=mask_value,
)
attn = F.softmax(energy, dim=-1)
attn = self.dropout(attn)
- out = einsum("b h i j, b h j d -> b h i d", attn, bv)
- out = rearrange(out, "b h n d -> b n (h d)")
+ out = einsum("b h n i j, b h n j d -> b h n i d", attn, bv)
+ out = out.reshape(-1, n, d)
out = self.fc(out)
return out, attn
@@ -134,12 +140,12 @@ def expand_dim(t: Tensor, dim: int, k: int, unsqueeze: bool = True) -> Tensor:
return t.expand(*expand_shape)
-def look_around(x: Tensor, backward: int, pad_value: int = -1, dim: int = 2) -> Tensor:
+def look_around(x: Tensor, backward: int, pad_value: int = -1, dim: int = 3) -> Tensor:
"""Apply windowing."""
- n = x.shape[1]
+ n = x.shape[2]
dims = (len(x.shape) - dim) * (0, 0)
x_pad = F.pad(x, (*dims, backward, 0), value=pad_value)
- tensors = [x_pad[:, ind : (ind + n), ...] for ind in range(backward + 1)]
+ tensors = [x_pad[:, :, ind : (ind + n), ...] for ind in range(backward + 1)]
return torch.cat(tensors, dim=dim)
@@ -150,15 +156,17 @@ def apply_input_mask(
backward: int,
window_size: int,
num_windows: int,
+ num_heads: int,
mask_value: Tensor,
) -> Tensor:
"""Applies input mask to energy tensor."""
h = b // mask.shape[0]
- mask = mask.reshape(-1, window_size, num_windows)
+ mask = torch.cat([mask] * num_heads)
+ mask = mask.reshape(-1, num_heads, num_windows, window_size)
mq = mk = mask
mk = look_around(mk, pad_value=False, backward=backward)
- mask = mq[:, :, :, None] * mk[:, :, None, :]
- mask = merge_dims(0, 1, expand_dim(mask, 1, h))
+ mask = mq[:, :, :, :, None] * mk[:, :, :, None, :]
+ mask = merge_dims(1, 2, expand_dim(mask, 2, h))
energy.masked_fill_(~mask, mask_value)
del mask
return energy
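
Putting the pieces together, the sketch below shows the shape of the multihead windowed attention this commit arrives at: buckets carry an explicit head dimension, the einsums run over five axes, and the causal and padding masks are built from per-head bucket indices. This is an illustrative re-implementation, not the repository's module: the function name, hyper-parameters, the final (w s)/(h e) merge, and the omission of dropout, the output projection, and the optional input mask are all assumptions made to keep the example self-contained.

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor, einsum


def look_around(x: Tensor, backward: int, pad_value: int = -1, dim: int = 3) -> Tensor:
    # Same helper as sketched above the diff: pad `backward` windows on the
    # left and concatenate them with each window along the window_size axis.
    num_windows = x.shape[2]
    pad_dims = (len(x.shape) - dim) * (0, 0)
    x_pad = F.pad(x, (*pad_dims, backward, 0), value=pad_value)
    shifted = [x_pad[:, :, ind : ind + num_windows, ...] for ind in range(backward + 1)]
    return torch.cat(shifted, dim=dim)


def windowed_attention(
    q: Tensor, k: Tensor, v: Tensor, window_size: int, look_back: int = 1
) -> Tensor:
    """Causal local attention over (batch, heads, seq_len, head_dim) tensors."""
    b, h, n, e = q.shape
    assert n % window_size == 0, "sequence length must be divisible by the window size"
    num_windows = n // window_size
    scale = e ** -0.5

    # Per-head bucket indices, shaped (1, h, num_windows, window_size).
    b_n = torch.arange(h * n, device=q.device).reshape(1, h, num_windows, window_size)

    # Split queries, keys, and values into windows: (b, h, w, window_size, e).
    bq, bk, bv = (t.reshape(b, h, num_windows, window_size, e) for t in (q, k, v))

    # Give every window access to the previous `look_back` windows of keys/values.
    bk = look_around(bk, backward=look_back)
    bv = look_around(bv, backward=look_back)
    bq_k = look_around(b_n, backward=look_back)

    # Attention scores within each window: (b, h, w, i, j).
    energy = einsum("b h n i d, b h n j d -> b h n i j", bq, bk) * scale
    mask_value = -torch.finfo(energy.dtype).max

    # Causal mask: queries may not attend to later bucket indices.
    causal_mask = b_n[:, :, :, :, None] < bq_k[:, :, :, None, :]
    energy = energy.masked_fill(causal_mask, mask_value)

    # Mask the -1 buckets introduced by the look_around padding.
    energy = energy.masked_fill(bq_k[:, :, :, None, :] == -1, mask_value)

    attn = F.softmax(energy, dim=-1)
    out = einsum("b h n i j, b h n j d -> b h n i d", attn, bv)
    # Merge windows back into a sequence and heads back into the model dim.
    return rearrange(out, "b h w s e -> b (w s) (h e)")


if __name__ == "__main__":
    q, k, v = (torch.randn(2, 4, 16, 32) for _ in range(3))
    out = windowed_attention(q, k, v, window_size=4)
    print(out.shape)  # torch.Size([2, 16, 128])
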