From 04e818853289d4f7cdddb3f09164636985dc5b1d Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Mon, 1 Nov 2021 00:35:54 +0100
Subject: Fix bugs in local attention

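Fold the separate query/key/value projections into a single to_qkv
linear layer, handle rotary embeddings through an optional
RotaryEmbedding attribute instead of a rotary_pos_emb argument to
forward, and drop the explicit heads axis from bucketing, look_around,
and the input mask by merging heads into the batch dimension. The
forward pass is split into _to_embeddings, _create_buckets, and
_apply_masks helpers.

Rough sketch of the new call pattern (the dim/num_heads/dim_head values
and the RotaryEmbedding constructor below are assumptions, not taken
from this diff):

    attention = LocalAttention(
        dim=256,
        num_heads=4,
        dim_head=64,
        window_size=128,
        rotary_embedding=RotaryEmbedding(64),  # assumed constructor signature
    )
    out, attn = attention(x, mask=mask)  # no rotary_pos_emb argument anymore
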
---
 .../networks/transformer/local_attention.py        | 138 +++++++++++++--------
 1 file changed, 85 insertions(+), 53 deletions(-)

(limited to 'text_recognizer/networks')

diff --git a/text_recognizer/networks/transformer/local_attention.py b/text_recognizer/networks/transformer/local_attention.py
index 002069c..628cf7e 100644
--- a/text_recognizer/networks/transformer/local_attention.py
+++ b/text_recognizer/networks/transformer/local_attention.py
@@ -8,14 +8,16 @@ from operator import mul
 from typing import Optional, Tuple
 
 import attr
-from einops import rearrange
 import torch
 from torch import einsum
 from torch import nn
 from torch import Tensor
 import torch.nn.functional as F
 
-from text_recognizer.networks.transformer.attention import apply_rotary_emb
+from text_recognizer.networks.transformer.embeddings.rotary import (
+    RotaryEmbedding,
+    rotate_half,
+)
 
 
 @attr.s(eq=False)
@@ -28,6 +30,7 @@ class LocalAttention(nn.Module):
     window_size: int = attr.ib(default=128)
     look_back: int = attr.ib(default=1)
     dropout_rate: float = attr.ib(default=0.0)
+    rotary_embedding: Optional[RotaryEmbedding] = attr.ib(default=None)
 
     def __attrs_pre_init__(self) -> None:
         """Pre init constructor."""
@@ -36,69 +39,63 @@ class LocalAttention(nn.Module):
     def __attrs_post_init__(self) -> None:
         """Post init constructor."""
         self.scale = self.dim ** -0.5
-        inner_dim = self.dim * self.dim_head
-
-        self.query = nn.Linear(self.dim, inner_dim, bias=False)
-        self.key = nn.Linear(self.dim, inner_dim, bias=False)
-        self.value = nn.Linear(self.dim, inner_dim, bias=False)
+        inner_dim = self.num_heads * self.dim_head
 
+        self.to_qkv = nn.Linear(self.dim, 3 * inner_dim, bias=False)
         self.dropout = nn.Dropout(p=self.dropout_rate)
 
         # Feedforward
         self.fc = nn.Linear(inner_dim, self.dim)
 
-    def forward(
-        self,
-        x: Tensor,
-        mask: Optional[Tensor] = None,
-        rotary_pos_emb: Optional[Tensor] = None,
-    ) -> Tuple[Tensor, Tensor]:
-        """Computes windowed attention."""
-        b, n, d = x.shape
-        if not n % self.window_size:
-            RuntimeError(
-                f"Sequence length {n} must be divisable with window size {self.window_size}"
-            )
+    def _to_embeddings(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+        """Convert input into query, key, and value."""
 
-        q = self.query(x)
-        k = self.key(x)
-        v = self.value(x)
-        q, k, v = map(
-            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (q, k, v)
-        )
-        q, k, v = (
-            apply_rotary_emb(q, k, v, rotary_pos_emb)
-            if rotary_pos_emb is not None
-            else (q, k, v,)
-        )
+        def _split_heads(t: Tensor) -> Tensor:
+            return _reshape_dim(t, -1, (-1, self.dim_head)).transpose(1, 2).contiguous()
 
-        num_windows = n // self.window_size
+        def _merge_into_batch(t: Tensor) -> Tensor:
+            return t.reshape(-1, *t.shape[-2:])
 
-        # Compute buckets
-        b_n = (
-            torch.arange(self.num_heads * n)
-            .type_as(q)
-            .reshape(1, self.num_heads, num_windows, self.window_size)
-        )
+        qkv = self.to_qkv(x).chunk(3, dim=-1)
+        q, k, v = map(_split_heads, qkv)
+
+        q, k, v = map(_merge_into_batch, (q, k, v))
+
+        if self.rotary_embedding is not None:
+            embedding = self.rotary_embedding(q)
+            q, k = _apply_rotary_emb(q, k, embedding)
+        return q, k, v
+
+    def _create_buckets(
+        self, q: Tensor, k: Tensor, v: Tensor, n: int, b: int, num_windows: int
+    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+        b_n = torch.arange(n).type_as(q).reshape(1, num_windows, self.window_size)
         bq, bk, bv = map(
-            lambda t: t.reshape(b, self.num_heads, num_windows, self.window_size, -1),
-            (q, k, v),
+            lambda t: t.reshape(b, num_windows, self.window_size, -1), (q, k, v),
         )
 
         bk = look_around(bk, backward=self.look_back)
         bv = look_around(bv, backward=self.look_back)
         bq_k = look_around(b_n, backward=self.look_back)
+        return b_n, bq, bk, bv, bq_k
 
-        # Compute the attention.
-        energy = einsum("b h n i d, b h n j d -> b h n i j", bq, bk) * self.scale
+    def _apply_masks(
+        self,
+        b: int,
+        energy: Tensor,
+        b_n: Tensor,
+        bq_k: Tensor,
+        mask: Optional[Tensor],
+        num_windows: int,
+    ) -> Tensor:
         mask_value = -torch.finfo(energy.dtype).max
 
         # Causal mask.
-        causal_mask = b_n[:, :, :, :, None] < bq_k[:, :, :, None, :]
+        causal_mask = b_n[:, :, :, None] < bq_k[:, :, None, :]
         energy = energy.masked_fill_(causal_mask, mask_value)
         del causal_mask
 
-        bucket_mask = bq_k[:, :, :, None, :] == -1
+        bucket_mask = bq_k[:, :, None, :] == -1
         energy.masked_fill_(bucket_mask, mask_value)
         del bucket_mask
 
@@ -109,20 +106,57 @@ class LocalAttention(nn.Module):
             backward=self.look_back,
             window_size=self.window_size,
             num_windows=num_windows,
-            num_heads=self.num_heads,
             mask_value=mask_value,
         )
+        return energy
+
+    def forward(
+        self, x: Tensor, mask: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """Computes windowed attention."""
+        b, n, _ = x.shape
+
+        if n % self.window_size:
+            raise RuntimeError(
+                f"Sequence length {n} must be divisible by window size {self.window_size}"
+            )
+
+        num_windows = n // self.window_size
+
+        q, k, v = self._to_embeddings(x)
+        d = q.shape[-1]
+
+        # Compute buckets
+        b_n, bq, bk, bv, bq_k = self._create_buckets(q, k, v, n, b, num_windows)
+
+        # Compute the attention.
+        energy = einsum("b h i d, b h j d -> b h i j", bq, bk) * self.scale
+        energy = self._apply_masks(b, energy, b_n, bq_k, mask, num_windows)
 
         attn = F.softmax(energy, dim=-1)
         attn = self.dropout(attn)
 
-        out = einsum("b h n i j, b h n j d -> b h n i d", attn, bv)
+        out = einsum("b h i j, b h j d -> b h i d", attn, bv)
         out = out.reshape(-1, n, d)
 
+        out = out.reshape(b, self.num_heads, n, -1).transpose(1, 2).reshape(b, n, -1)
         out = self.fc(out)
         return out, attn
 
 
+def _apply_rotary_emb(q: Tensor, k: Tensor, freqs: Tensor) -> Tuple[Tensor, Tensor]:
+    q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k))
+    return q, k
+
+
+def _reshape_dim(t: Tensor, dim: int, split_dims: Tuple[int, int]) -> Tensor:
+    shape = list(t.shape)
+    dims = len(t.shape)
+    dim = (dim + dims) % dims
+    shape[dim : dim + 1] = split_dims
+    return t.reshape(shape)
+
+
 def merge_dims(ind_from: int, ind_to: int, tensor: Tensor) -> Tensor:
     """Merge dimensions."""
     shape = list(tensor.shape)
@@ -140,12 +174,12 @@ def expand_dim(t: Tensor, dim: int, k: int, unsqueeze: bool = True) -> Tensor:
     return t.expand(*expand_shape)
 
 
-def look_around(x: Tensor, backward: int, pad_value: int = -1, dim: int = 3) -> Tensor:
+def look_around(x: Tensor, backward: int, pad_value: int = -1, dim: int = 2) -> Tensor:
     """Apply windowing."""
-    n = x.shape[2]
+    n = x.shape[1]
     dims = (len(x.shape) - dim) * (0, 0)
     x_pad = F.pad(x, (*dims, backward, 0), value=pad_value)
-    tensors = [x_pad[:, :, ind : (ind + n), ...] for ind in range(backward + 1)]
+    tensors = [x_pad[:, ind : (ind + n), ...] for ind in range(backward + 1)]
     return torch.cat(tensors, dim=dim)
 
 
@@ -156,17 +190,15 @@ def apply_input_mask(
     backward: int,
     window_size: int,
     num_windows: int,
-    num_heads: int,
     mask_value: Tensor,
 ) -> Tensor:
     """Applies input mask to energy tensor."""
     h = b // mask.shape[0]
-    mask = torch.cat([mask] * num_heads)
-    mask = mask.reshape(-1, num_heads, num_windows, window_size)
+    mask = mask.reshape(-1, num_windows, window_size)
     mq = mk = mask
     mk = look_around(mk, pad_value=False, backward=backward)
-    mask = mq[:, :, :, :, None] * mk[:, :, :, None, :]
-    mask = merge_dims(1, 2, expand_dim(mask, 2, h))
+    mask = mq[:, :, :, None] * mk[:, :, None, :]
+    mask = merge_dims(0, 1, expand_dim(mask, 1, h))
     energy.masked_fill_(~mask, mask_value)
     del mask
     return energy
-- 
cgit v1.2.3-70-g09d2