summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/local_attention.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-01 00:35:54 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-01 00:35:54 +0100
commit04e818853289d4f7cdddb3f09164636985dc5b1d (patch)
tree7be6741faf29fa11dda9d408e55c796076bb3a9f /text_recognizer/networks/transformer/local_attention.py
parent5b2c729e819d1e1e5a6752a3952592259ea48f8a (diff)
Fix bugs in local attention
Diffstat (limited to 'text_recognizer/networks/transformer/local_attention.py')
-rw-r--r--text_recognizer/networks/transformer/local_attention.py138
1 files changed, 85 insertions, 53 deletions
diff --git a/text_recognizer/networks/transformer/local_attention.py b/text_recognizer/networks/transformer/local_attention.py
index 002069c..628cf7e 100644
--- a/text_recognizer/networks/transformer/local_attention.py
+++ b/text_recognizer/networks/transformer/local_attention.py
@@ -8,14 +8,16 @@ from operator import mul
from typing import Optional, Tuple
import attr
-from einops import rearrange
import torch
from torch import einsum
from torch import nn
from torch import Tensor
import torch.nn.functional as F
-from text_recognizer.networks.transformer.attention import apply_rotary_emb
+from text_recognizer.networks.transformer.embeddings.rotary import (
+ RotaryEmbedding,
+ rotate_half,
+)
@attr.s(eq=False)
@@ -28,6 +30,7 @@ class LocalAttention(nn.Module):
window_size: int = attr.ib(default=128)
look_back: int = attr.ib(default=1)
dropout_rate: float = attr.ib(default=0.0)
+ rotary_embedding: Optional[RotaryEmbedding] = attr.ib(default=None)
def __attrs_pre_init__(self) -> None:
"""Pre init constructor."""
@@ -36,69 +39,63 @@ class LocalAttention(nn.Module):
def __attrs_post_init__(self) -> None:
"""Post init constructor."""
self.scale = self.dim ** -0.5
- inner_dim = self.dim * self.dim_head
-
- self.query = nn.Linear(self.dim, inner_dim, bias=False)
- self.key = nn.Linear(self.dim, inner_dim, bias=False)
- self.value = nn.Linear(self.dim, inner_dim, bias=False)
+ inner_dim = self.num_heads * self.dim_head
+ self.to_qkv = nn.Linear(self.dim, 3 * inner_dim, bias=False)
self.dropout = nn.Dropout(p=self.dropout_rate)
# Feedforward
self.fc = nn.Linear(inner_dim, self.dim)
- def forward(
- self,
- x: Tensor,
- mask: Optional[Tensor] = None,
- rotary_pos_emb: Optional[Tensor] = None,
- ) -> Tuple[Tensor, Tensor]:
- """Computes windowed attention."""
- b, n, d = x.shape
- if not n % self.window_size:
- RuntimeError(
- f"Sequence length {n} must be divisable with window size {self.window_size}"
- )
+ def _to_embeddings(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+ """Convert input into query, key, and value."""
- q = self.query(x)
- k = self.key(x)
- v = self.value(x)
- q, k, v = map(
- lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (q, k, v)
- )
- q, k, v = (
- apply_rotary_emb(q, k, v, rotary_pos_emb)
- if rotary_pos_emb is not None
- else (q, k, v,)
- )
+ def _split_heads(t: Tensor) -> Tensor:
+ return _reshape_dim(t, -1, (-1, self.dim_head)).transpose(1, 2).contiguous()
- num_windows = n // self.window_size
+ def _merge_into_batch(t: Tensor) -> Tensor:
+ return t.reshape(-1, *t.shape[-2:])
- # Compute buckets
- b_n = (
- torch.arange(self.num_heads * n)
- .type_as(q)
- .reshape(1, self.num_heads, num_windows, self.window_size)
- )
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
+ q, k, v = map(_split_heads, qkv)
+
+ q, k, v = map(_merge_into_batch, (q, k, v))
+
+ if self.rotary_embedding is not None:
+ embedding = self.rotary_embedding(q)
+ q, k = _apply_rotary_emb(q, k, embedding)
+ return q, k, v
+
+ def _create_buckets(
+ self, q: Tensor, k: Tensor, v: Tensor, n: int, b: int, num_windows: int
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+ b_n = torch.arange(n).type_as(q).reshape(1, num_windows, self.window_size)
bq, bk, bv = map(
- lambda t: t.reshape(b, self.num_heads, num_windows, self.window_size, -1),
- (q, k, v),
+ lambda t: t.reshape(b, num_windows, self.window_size, -1), (q, k, v),
)
bk = look_around(bk, backward=self.look_back)
bv = look_around(bv, backward=self.look_back)
bq_k = look_around(b_n, backward=self.look_back)
+ return b_n, bq, bk, bv, bq_k
- # Compute the attention.
- energy = einsum("b h n i d, b h n j d -> b h n i j", bq, bk) * self.scale
+ def _apply_masks(
+ self,
+ b: int,
+ energy: Tensor,
+ b_n: Tensor,
+ bq_k: Tensor,
+ mask: Tensor,
+ num_windows: int,
+ ) -> Tensor:
mask_value = -torch.finfo(energy.dtype).max
# Causal mask.
- causal_mask = b_n[:, :, :, :, None] < bq_k[:, :, :, None, :]
+ causal_mask = b_n[:, :, :, None] < bq_k[:, :, None, :]
energy = energy.masked_fill_(causal_mask, mask_value)
del causal_mask
- bucket_mask = bq_k[:, :, :, None, :] == -1
+ bucket_mask = bq_k[:, :, None, :] == -1
energy.masked_fill_(bucket_mask, mask_value)
del bucket_mask
@@ -109,20 +106,57 @@ class LocalAttention(nn.Module):
backward=self.look_back,
window_size=self.window_size,
num_windows=num_windows,
- num_heads=self.num_heads,
mask_value=mask_value,
)
+ return energy
+
+ def forward(
+ self, x: Tensor, mask: Optional[Tensor] = None,
+ ) -> Tuple[Tensor, Tensor]:
+ """Computes windowed attention."""
+ b, n, _ = x.shape
+
+ if not n % self.window_size:
+ RuntimeError(
+ f"Sequence length {n} must be divisable with window size {self.window_size}"
+ )
+
+ num_windows = n // self.window_size
+
+ q, k, v = self._to_embeddings(x)
+ d = q.shape[-1]
+
+ # Compute buckets
+ b_n, bq, bk, bv, bq_k = self._create_buckets(q, k, v, n, b, num_windows)
+
+ # Compute the attention.
+ energy = einsum("b h i d, b h j d -> b h i j", bq, bk) * self.scale
+ energy = self._apply_masks(b, energy, b_n, bq_k, mask, num_windows)
attn = F.softmax(energy, dim=-1)
attn = self.dropout(attn)
- out = einsum("b h n i j, b h n j d -> b h n i d", attn, bv)
+ out = einsum("b h i j, b h j d -> b h i d", attn, bv)
out = out.reshape(-1, n, d)
+ out = out.reshape(b, self.num_heads, n, -1).transpose(1, 2).reshape(b, n, -1)
out = self.fc(out)
return out, attn
+def _apply_rotary_emb(q: Tensor, k: Tensor, freqs: Tensor) -> Tuple[Tensor, Tensor]:
+ q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k))
+ return q, k
+
+
+def _reshape_dim(t: Tensor, dim: int, split_dims: Tuple[int, int]) -> Tensor:
+ shape = list(t.shape)
+ dims = len(t.shape)
+ dim = (dim + dims) % dims
+ shape[dim : dim + 1] = split_dims
+ return t.reshape(shape)
+
+
def merge_dims(ind_from: int, ind_to: int, tensor: Tensor) -> Tensor:
"""Merge dimensions."""
shape = list(tensor.shape)
@@ -140,12 +174,12 @@ def expand_dim(t: Tensor, dim: int, k: int, unsqueeze: bool = True) -> Tensor:
return t.expand(*expand_shape)
-def look_around(x: Tensor, backward: int, pad_value: int = -1, dim: int = 3) -> Tensor:
+def look_around(x: Tensor, backward: int, pad_value: int = -1, dim: int = 2) -> Tensor:
"""Apply windowing."""
- n = x.shape[2]
+ n = x.shape[1]
dims = (len(x.shape) - dim) * (0, 0)
x_pad = F.pad(x, (*dims, backward, 0), value=pad_value)
- tensors = [x_pad[:, :, ind : (ind + n), ...] for ind in range(backward + 1)]
+ tensors = [x_pad[:, ind : (ind + n), ...] for ind in range(backward + 1)]
return torch.cat(tensors, dim=dim)
@@ -156,17 +190,15 @@ def apply_input_mask(
backward: int,
window_size: int,
num_windows: int,
- num_heads: int,
mask_value: Tensor,
) -> Tensor:
"""Applies input mask to energy tensor."""
h = b // mask.shape[0]
- mask = torch.cat([mask] * num_heads)
- mask = mask.reshape(-1, num_heads, num_windows, window_size)
+ mask = mask.reshape(-1, num_windows, window_size)
mq = mk = mask
mk = look_around(mk, pad_value=False, backward=backward)
- mask = mq[:, :, :, :, None] * mk[:, :, :, None, :]
- mask = merge_dims(1, 2, expand_dim(mask, 2, h))
+ mask = mq[:, :, :, None] * mk[:, :, None, :]
+ mask = merge_dims(0, 1, expand_dim(mask, 1, h))
energy.masked_fill_(~mask, mask_value)
del mask
return energy