From 01b11ead9470b40ca24e41dca59ac6a8b3f65186 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Thu, 28 Oct 2021 21:20:02 +0200
Subject: Fix multihead local attention

---
 .../networks/transformer/local_attention.py  | 44 +++++++++++++---------
 1 file changed, 26 insertions(+), 18 deletions(-)

(limited to 'text_recognizer')

diff --git a/text_recognizer/networks/transformer/local_attention.py b/text_recognizer/networks/transformer/local_attention.py
index 134089c..002069c 100644
--- a/text_recognizer/networks/transformer/local_attention.py
+++ b/text_recognizer/networks/transformer/local_attention.py
@@ -54,7 +54,7 @@ class LocalAttention(nn.Module):
         rotary_pos_emb: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor]:
         """Computes windowed attention."""
-        b, n, _ = x.shape
+        b, n, d = x.shape
         if not n % self.window_size:
             RuntimeError(
                 f"Sequence length {n} must be divisable with window size {self.window_size}"
             )
@@ -75,25 +75,30 @@ class LocalAttention(nn.Module):
         num_windows = n // self.window_size

         # Compute buckets
-        b_n = torch.arange(n).type_as(q).reshape(1, num_windows, self.window_size)
+        b_n = (
+            torch.arange(self.num_heads * n)
+            .type_as(q)
+            .reshape(1, self.num_heads, num_windows, self.window_size)
+        )
         bq, bk, bv = map(
-            lambda t: t.reshape(b, num_windows, self.window_size, -1), (q, k, v)
+            lambda t: t.reshape(b, self.num_heads, num_windows, self.window_size, -1),
+            (q, k, v),
         )
-        bk = look_around(bk, backward=self.backward)
-        bv = look_around(bv, backward=self.backward)
-        bq_k = look_around(b_n, backward=self.backward)
+        bk = look_around(bk, backward=self.look_back)
+        bv = look_around(bv, backward=self.look_back)
+        bq_k = look_around(b_n, backward=self.look_back)

         # Compute the attention.
-        energy = einsum("b h i d, b h j d -> b h i j", bq, bk) * self.scale
+        energy = einsum("b h n i d, b h n j d -> b h n i j", bq, bk) * self.scale
         mask_value = -torch.finfo(energy.dtype).max

         # Causal mask.
-        causal_mask = b_n[:, :, :, None] < bq_k[:, :, None, :]
+        causal_mask = b_n[:, :, :, :, None] < bq_k[:, :, :, None, :]
         energy = energy.masked_fill_(causal_mask, mask_value)
         del causal_mask

-        bucket_mask = bq_k[:, :, None, :] == -1
+        bucket_mask = bq_k[:, :, :, None, :] == -1
         energy.masked_fill_(bucket_mask, mask_value)
         del bucket_mask

@@ -101,17 +106,18 @@ class LocalAttention(nn.Module):
             b,
             energy=energy,
             mask=mask,
-            backward=self.backward,
+            backward=self.look_back,
             window_size=self.window_size,
             num_windows=num_windows,
+            num_heads=self.num_heads,
             mask_value=mask_value,
         )

         attn = F.softmax(energy, dim=-1)
         attn = self.dropout(attn)

-        out = einsum("b h i j, b h j d -> b h i d", attn, bv)
-        out = rearrange(out, "b h n d -> b n (h d)")
+        out = einsum("b h n i j, b h n j d -> b h n i d", attn, bv)
+        out = out.reshape(-1, n, d)
         out = self.fc(out)
         return out, attn

@@ -134,12 +140,12 @@ def expand_dim(t: Tensor, dim: int, k: int, unsqueeze: bool = True) -> Tensor:
     return t.expand(*expand_shape)


-def look_around(x: Tensor, backward: int, pad_value: int = -1, dim: int = 2) -> Tensor:
+def look_around(x: Tensor, backward: int, pad_value: int = -1, dim: int = 3) -> Tensor:
     """Apply windowing."""
-    n = x.shape[1]
+    n = x.shape[2]
     dims = (len(x.shape) - dim) * (0, 0)
     x_pad = F.pad(x, (*dims, backward, 0), value=pad_value)
-    tensors = [x_pad[:, ind : (ind + n), ...] for ind in range(backward + 1)]
+    tensors = [x_pad[:, :, ind : (ind + n), ...] for ind in range(backward + 1)]
     return torch.cat(tensors, dim=dim)


@@ -150,15 +156,17 @@ def apply_input_mask(
     backward: int,
     window_size: int,
     num_windows: int,
+    num_heads: int,
     mask_value: Tensor,
 ) -> Tensor:
     """Applies input mask to energy tensor."""
     h = b // mask.shape[0]
-    mask = mask.reshape(-1, window_size, num_windows)
+    mask = torch.cat([mask] * num_heads)
+    mask = mask.reshape(-1, num_heads, num_windows, window_size)
     mq = mk = mask
     mk = look_around(mk, pad_value=False, backward=backward)
-    mask = mq[:, :, :, None] * mk[:, :, None, :]
-    mask = merge_dims(0, 1, expand_dim(mask, 1, h))
+    mask = mq[:, :, :, :, None] * mk[:, :, :, None, :]
+    mask = merge_dims(1, 2, expand_dim(mask, 2, h))
     energy.masked_fill_(~mask, mask_value)
     del mask
     return energy
--
cgit v1.2.3-70-g09d2
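
A standalone sketch (not part of the commit) of the shape convention this patch moves to: queries, keys, and values keep an explicit head axis, windows sit on dim 2, and look_around pads and concatenates along dim 3. The look_around helper below follows the patched version (type hints inlined); the batch size, head count, window size, and look-back depth are arbitrary example values.

# Sketch of the patched multihead windowing shapes; all sizes are example values.
import torch
import torch.nn.functional as F


def look_around(x: torch.Tensor, backward: int, pad_value: int = -1, dim: int = 3) -> torch.Tensor:
    """Extend each window with its `backward` preceding windows, front-padded with pad_value."""
    n = x.shape[2]  # number of windows
    dims = (len(x.shape) - dim) * (0, 0)
    x_pad = F.pad(x, (*dims, backward, 0), value=pad_value)
    tensors = [x_pad[:, :, ind : (ind + n), ...] for ind in range(backward + 1)]
    return torch.cat(tensors, dim=dim)


b, h, n, d = 2, 4, 16, 8  # batch, heads, sequence length, head dim
window_size, look_back = 4, 1
num_windows = n // window_size

# Keys with an explicit head axis, split into windows: (b, h, num_windows, window_size, d).
k = torch.randn(b, h, n, d).reshape(b, h, num_windows, window_size, d)

# Each window is extended with the previous `look_back` windows of keys:
# (b, h, num_windows, window_size * (look_back + 1), d).
bk = look_around(k, backward=look_back)
print(bk.shape)  # torch.Size([2, 4, 4, 8, 8])

# Bucket indices per head, as in the patch; front padding becomes -1 so those
# positions can be masked out, and the causal mask compares query vs. key indices.
b_n = torch.arange(h * n).float().reshape(1, h, num_windows, window_size)
bq_k = look_around(b_n, backward=look_back)
causal_mask = b_n[:, :, :, :, None] < bq_k[:, :, :, None, :]
print(causal_mask.shape)  # torch.Size([1, 4, 4, 4, 8])

Extending each key/value window with the window before it is what lets a query position attend causally within its own window and into the previous one, while the -1 padding marks slots that fall before the start of the sequence.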