author    | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-01 00:35:54 +0100
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-01 00:35:54 +0100
commit    | 04e818853289d4f7cdddb3f09164636985dc5b1d (patch)
tree      | 7be6741faf29fa11dda9d408e55c796076bb3a9f
parent    | 5b2c729e819d1e1e5a6752a3952592259ea48f8a (diff)
Fix bugs in local attention
-rw-r--r-- | text_recognizer/networks/transformer/local_attention.py | 138
1 file changed, 85 insertions, 53 deletions
diff --git a/text_recognizer/networks/transformer/local_attention.py b/text_recognizer/networks/transformer/local_attention.py
index 002069c..628cf7e 100644
--- a/text_recognizer/networks/transformer/local_attention.py
+++ b/text_recognizer/networks/transformer/local_attention.py
@@ -8,14 +8,16 @@ from operator import mul
 from typing import Optional, Tuple
 
 import attr
-from einops import rearrange
 import torch
 from torch import einsum
 from torch import nn
 from torch import Tensor
 import torch.nn.functional as F
 
-from text_recognizer.networks.transformer.attention import apply_rotary_emb
+from text_recognizer.networks.transformer.embeddings.rotary import (
+    RotaryEmbedding,
+    rotate_half,
+)
 
 
 @attr.s(eq=False)
@@ -28,6 +30,7 @@ class LocalAttention(nn.Module):
     window_size: int = attr.ib(default=128)
     look_back: int = attr.ib(default=1)
     dropout_rate: float = attr.ib(default=0.0)
+    rotary_embedding: Optional[RotaryEmbedding] = attr.ib(default=None)
 
     def __attrs_pre_init__(self) -> None:
         """Pre init constructor."""
@@ -36,69 +39,63 @@ class LocalAttention(nn.Module):
     def __attrs_post_init__(self) -> None:
         """Post init constructor."""
         self.scale = self.dim ** -0.5
-        inner_dim = self.dim * self.dim_head
-
-        self.query = nn.Linear(self.dim, inner_dim, bias=False)
-        self.key = nn.Linear(self.dim, inner_dim, bias=False)
-        self.value = nn.Linear(self.dim, inner_dim, bias=False)
+        inner_dim = self.num_heads * self.dim_head
+        self.to_qkv = nn.Linear(self.dim, 3 * inner_dim, bias=False)
         self.dropout = nn.Dropout(p=self.dropout_rate)
 
         # Feedforward
         self.fc = nn.Linear(inner_dim, self.dim)
 
-    def forward(
-        self,
-        x: Tensor,
-        mask: Optional[Tensor] = None,
-        rotary_pos_emb: Optional[Tensor] = None,
-    ) -> Tuple[Tensor, Tensor]:
-        """Computes windowed attention."""
-        b, n, d = x.shape
-        if not n % self.window_size:
-            RuntimeError(
-                f"Sequence length {n} must be divisable with window size {self.window_size}"
-            )
+    def _to_embeddings(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+        """Convert input into query, key, and value."""
 
-        q = self.query(x)
-        k = self.key(x)
-        v = self.value(x)
-        q, k, v = map(
-            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (q, k, v)
-        )
-        q, k, v = (
-            apply_rotary_emb(q, k, v, rotary_pos_emb)
-            if rotary_pos_emb is not None
-            else (q, k, v,)
-        )
+        def _split_heads(t: Tensor) -> Tensor:
+            return _reshape_dim(t, -1, (-1, self.dim_head)).transpose(1, 2).contiguous()
 
-        num_windows = n // self.window_size
+        def _merge_into_batch(t: Tensor) -> Tensor:
+            return t.reshape(-1, *t.shape[-2:])
 
-        # Compute buckets
-        b_n = (
-            torch.arange(self.num_heads * n)
-            .type_as(q)
-            .reshape(1, self.num_heads, num_windows, self.window_size)
-        )
+        qkv = self.to_qkv(x).chunk(3, dim=-1)
+        q, k, v = map(_split_heads, qkv)
+
+        q, k, v = map(_merge_into_batch, (q, k, v))
+
+        if self.rotary_embedding is not None:
+            embedding = self.rotary_embedding(q)
+            q, k = _apply_rotary_emb(q, k, embedding)
+        return q, k, v
+
+    def _create_buckets(
+        self, q: Tensor, k: Tensor, v: Tensor, n: int, b: int, num_windows: int
+    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+        b_n = torch.arange(n).type_as(q).reshape(1, num_windows, self.window_size)
         bq, bk, bv = map(
-            lambda t: t.reshape(b, self.num_heads, num_windows, self.window_size, -1),
-            (q, k, v),
+            lambda t: t.reshape(b, num_windows, self.window_size, -1), (q, k, v),
        )
 
         bk = look_around(bk, backward=self.look_back)
         bv = look_around(bv, backward=self.look_back)
         bq_k = look_around(b_n, backward=self.look_back)
+        return b_n, bq, bk, bv, bq_k
 
-        # Compute the attention.
-        energy = einsum("b h n i d, b h n j d -> b h n i j", bq, bk) * self.scale
+    def _apply_masks(
+        self,
+        b: int,
+        energy: Tensor,
+        b_n: Tensor,
+        bq_k: Tensor,
+        mask: Tensor,
+        num_windows: int,
+    ) -> Tensor:
         mask_value = -torch.finfo(energy.dtype).max
 
         # Causal mask.
-        causal_mask = b_n[:, :, :, :, None] < bq_k[:, :, :, None, :]
+        causal_mask = b_n[:, :, :, None] < bq_k[:, :, None, :]
         energy = energy.masked_fill_(causal_mask, mask_value)
         del causal_mask
 
-        bucket_mask = bq_k[:, :, :, None, :] == -1
+        bucket_mask = bq_k[:, :, None, :] == -1
         energy.masked_fill_(bucket_mask, mask_value)
         del bucket_mask
 
@@ -109,20 +106,57 @@ class LocalAttention(nn.Module):
                 backward=self.look_back,
                 window_size=self.window_size,
                 num_windows=num_windows,
-                num_heads=self.num_heads,
                 mask_value=mask_value,
             )
+        return energy
+
+    def forward(
+        self, x: Tensor, mask: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """Computes windowed attention."""
+        b, n, _ = x.shape
+
+        if not n % self.window_size:
+            RuntimeError(
+                f"Sequence length {n} must be divisable with window size {self.window_size}"
+            )
+
+        num_windows = n // self.window_size
+
+        q, k, v = self._to_embeddings(x)
+        d = q.shape[-1]
+
+        # Compute buckets
+        b_n, bq, bk, bv, bq_k = self._create_buckets(q, k, v, n, b, num_windows)
+
+        # Compute the attention.
+        energy = einsum("b h i d, b h j d -> b h i j", bq, bk) * self.scale
+        energy = self._apply_masks(b, energy, b_n, bq_k, mask, num_windows)
 
         attn = F.softmax(energy, dim=-1)
         attn = self.dropout(attn)
 
-        out = einsum("b h n i j, b h n j d -> b h n i d", attn, bv)
+        out = einsum("b h i j, b h j d -> b h i d", attn, bv)
         out = out.reshape(-1, n, d)
+        out = out.reshape(b, self.num_heads, n, -1).transpose(1, 2).reshape(b, n, -1)
         out = self.fc(out)
         return out, attn
 
 
+def _apply_rotary_emb(q: Tensor, k: Tensor, freqs: Tensor) -> Tuple[Tensor, Tensor]:
+    q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k))
+    return q, k
+
+
+def _reshape_dim(t: Tensor, dim: int, split_dims: Tuple[int, int]) -> Tensor:
+    shape = list(t.shape)
+    dims = len(t.shape)
+    dim = (dim + dims) % dims
+    shape[dim : dim + 1] = split_dims
+    return t.reshape(shape)
+
+
 def merge_dims(ind_from: int, ind_to: int, tensor: Tensor) -> Tensor:
     """Merge dimensions."""
     shape = list(tensor.shape)
@@ -140,12 +174,12 @@ def expand_dim(t: Tensor, dim: int, k: int, unsqueeze: bool = True) -> Tensor:
     return t.expand(*expand_shape)
 
 
-def look_around(x: Tensor, backward: int, pad_value: int = -1, dim: int = 3) -> Tensor:
+def look_around(x: Tensor, backward: int, pad_value: int = -1, dim: int = 2) -> Tensor:
     """Apply windowing."""
-    n = x.shape[2]
+    n = x.shape[1]
     dims = (len(x.shape) - dim) * (0, 0)
     x_pad = F.pad(x, (*dims, backward, 0), value=pad_value)
-    tensors = [x_pad[:, :, ind : (ind + n), ...] for ind in range(backward + 1)]
+    tensors = [x_pad[:, ind : (ind + n), ...] for ind in range(backward + 1)]
     return torch.cat(tensors, dim=dim)
 
 
@@ -156,17 +190,15 @@ def apply_input_mask(
     backward: int,
     window_size: int,
     num_windows: int,
-    num_heads: int,
     mask_value: Tensor,
 ) -> Tensor:
     """Applies input mask to energy tensor."""
     h = b // mask.shape[0]
-    mask = torch.cat([mask] * num_heads)
-    mask = mask.reshape(-1, num_heads, num_windows, window_size)
+    mask = mask.reshape(-1, num_windows, window_size)
     mq = mk = mask
     mk = look_around(mk, pad_value=False, backward=backward)
-    mask = mq[:, :, :, :, None] * mk[:, :, :, None, :]
-    mask = merge_dims(1, 2, expand_dim(mask, 2, h))
+    mask = mq[:, :, :, None] * mk[:, :, None, :]
+    mask = merge_dims(0, 1, expand_dim(mask, 1, h))
     energy.masked_fill_(~mask, mask_value)
     del mask
    return energy
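
With the einops rearrange call dropped, the commit splits and merges the head dimension by hand: _split_heads uses the new _reshape_dim helper plus a transpose, and _merge_into_batch folds the heads into the batch dimension before bucketing. Below is a minimal sketch of that shape flow; _reshape_dim is copied from the diff above, while the toy sizes and the inlined transpose/reshape steps are illustrative assumptions, not part of the commit.

import torch
from torch import Tensor
from typing import Tuple


def _reshape_dim(t: Tensor, dim: int, split_dims: Tuple[int, int]) -> Tensor:
    """Split one dimension of t into split_dims (copied from the commit)."""
    shape = list(t.shape)
    dims = len(t.shape)
    dim = (dim + dims) % dims
    shape[dim : dim + 1] = split_dims
    return t.reshape(shape)


# Hypothetical toy sizes, for illustration only.
b, n, num_heads, dim_head = 2, 10, 4, 16

t = torch.randn(b, n, num_heads * dim_head)  # stand-in for one chunk of the fused to_qkv output
t = _reshape_dim(t, -1, (-1, dim_head))      # (b, n, num_heads, dim_head)
t = t.transpose(1, 2).contiguous()           # (b, num_heads, n, dim_head)   -- _split_heads
t = t.reshape(-1, *t.shape[-2:])             # (b * num_heads, n, dim_head)  -- _merge_into_batch
print(t.shape)                               # torch.Size([8, 10, 16])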
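
After the heads are merged into the batch, the bucketed tensors keep the window index at dim 1 and the window contents at dim 2, which is why look_around now defaults to dim=2 and reads the window count from x.shape[1]. The sketch below copies look_around from the diff and runs it on a made-up toy input so the windowing is visible.

import torch
import torch.nn.functional as F
from torch import Tensor


def look_around(x: Tensor, backward: int, pad_value: int = -1, dim: int = 2) -> Tensor:
    """Concatenate every window with its `backward` preceding windows (from the commit)."""
    n = x.shape[1]
    dims = (len(x.shape) - dim) * (0, 0)
    x_pad = F.pad(x, (*dims, backward, 0), value=pad_value)
    tensors = [x_pad[:, ind : (ind + n), ...] for ind in range(backward + 1)]
    return torch.cat(tensors, dim=dim)


# One sequence, three windows of size two, a single feature per position.
x = torch.arange(6.0).reshape(1, 3, 2, 1)
out = look_around(x, backward=1)
print(out.shape)        # torch.Size([1, 3, 4, 1]): each window now also sees the previous window
print(out[0, 1, :, 0])  # tensor([0., 1., 2., 3.]): window 1 = window 0 ++ window 1
print(out[0, 0, :, 0])  # tensor([-1., -1., 0., 1.]): window 0 is padded with pad_value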
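
The same position indices drive _apply_masks: b_n holds the absolute position of each query, look_around(b_n, ...) holds the key positions each window may attend to (with -1 marking the padded window), and two broadcast comparisons build the causal and padding masks whose True entries get filled with mask_value. Continuing the sketch above (it reuses the look_around function defined there) with made-up sizes:

# Toy setup: n = 6 positions, window_size = 2, look_back = 1.
n, window_size, look_back = 6, 2, 1
num_windows = n // window_size

b_n = torch.arange(n).float().reshape(1, num_windows, window_size)  # query positions per window
bq_k = look_around(b_n, backward=look_back)                         # visible key positions per window

causal_mask = b_n[:, :, :, None] < bq_k[:, :, None, :]   # key lies in the query's future
bucket_mask = bq_k[:, :, None, :] == -1                  # key comes from the padding window

print(bq_k[0])
# tensor([[-1., -1.,  0.,  1.],
#         [ 0.,  1.,  2.,  3.],
#         [ 2.,  3.,  4.,  5.]])
print(causal_mask[0, 1].int())
# tensor([[0, 0, 0, 1],
#         [0, 0, 0, 0]], dtype=torch.int32)
print(bucket_mask[0, 0].int())
# tensor([[1, 1, 0, 0],
#         [1, 1, 0, 0]], dtype=torch.int32)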