From 2cb2c5b38f0711267fecfe9c5e10940f4b4f79fc Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Wed, 17 Nov 2021 22:42:29 +0100
Subject: Remove local attention

---
 text_recognizer/networks/transformer/layers.py     |  16 +-
 .../networks/transformer/local_attention.py        | 233 ---------------------
 2 files changed, 1 insertion(+), 248 deletions(-)
 delete mode 100644 text_recognizer/networks/transformer/local_attention.py

(limited to 'text_recognizer')

diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index 8387fa4..67558ad 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -6,7 +6,6 @@ import attr
 from torch import nn, Tensor
 
 from text_recognizer.networks.transformer.attention import Attention
-from text_recognizer.networks.transformer.local_attention import LocalAttention
 from text_recognizer.networks.transformer.mlp import FeedForward
 from text_recognizer.networks.transformer.residual import Residual
 
@@ -24,18 +23,13 @@ class AttentionLayers(nn.Module):
     norm: Type[nn.Module] = attr.ib()
     ff: FeedForward = attr.ib()
     cross_attn: Optional[Attention] = attr.ib(default=None)
-    local_self_attn: Optional[LocalAttention] = attr.ib(default=None)
     pre_norm: bool = attr.ib(default=True)
-    local_depth: Optional[int] = attr.ib(default=None)
     has_pos_emb: bool = attr.ib(default=False)
     layer_types: Tuple[str, ...] = attr.ib(init=False)
     layers: nn.ModuleList = attr.ib(init=False)
 
     def __attrs_post_init__(self) -> None:
         """Post init configuration."""
-        if self.local_self_attn is not None:
-            if self.local_depth is None:
-                ValueError("Local depth has to be specified")
         self.layer_types = self._get_layer_types() * self.depth
         self.layers = self._build_network()
 
@@ -45,14 +39,8 @@ class AttentionLayers(nn.Module):
             return "a", "c", "f"
         return "a", "f"
 
-    def _self_attn_block(self, i: int) -> Type[nn.Module]:
-        if self.local_depth is not None and i < self.local_depth:
-            return deepcopy(self.local_self_attn)
-        return deepcopy(self.self_attn)
-
     def _delete(self) -> None:
         del self.self_attn
-        del self.local_self_attn
         del self.ff
         del self.norm
         del self.cross_attn
@@ -60,11 +48,9 @@ class AttentionLayers(nn.Module):
     def _build_network(self) -> nn.ModuleList:
         """Configures transformer network."""
         layers = nn.ModuleList([])
-        self_attn_depth = 0
         for layer_type in self.layer_types:
             if layer_type == "a":
-                layer = self._self_attn_block(self_attn_depth)
-                self_attn_depth += 1
+                layer = deepcopy(self.self_attn)
             elif layer_type == "c":
                 layer = deepcopy(self.cross_attn)
             elif layer_type == "f":
diff --git a/text_recognizer/networks/transformer/local_attention.py b/text_recognizer/networks/transformer/local_attention.py
deleted file mode 100644
index a008bab..0000000
--- a/text_recognizer/networks/transformer/local_attention.py
+++ /dev/null
@@ -1,233 +0,0 @@
-"""Local attention module.
-
-Also stolen from lucidrains from here:
-https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py
-"""
-from functools import reduce
-import math
-from operator import mul
-from typing import List, Optional, Tuple
-
-import attr
-import torch
-from torch import einsum
-from torch import nn
-from torch import Tensor
-import torch.nn.functional as F
-
-from text_recognizer.networks.transformer.embeddings.rotary import (
-    RotaryEmbedding,
-    rotate_half,
-)
-
-
-@attr.s(eq=False)
-class LocalAttention(nn.Module):
-    """Local windowed attention."""
-
-    dim: int = attr.ib()
-    num_heads: int = attr.ib()
-    dim_head: int = attr.ib(default=64)
-    window_size: int = attr.ib(default=128)
-    look_back: int = attr.ib(default=1)
-    dropout_rate: float = attr.ib(default=0.0)
-    autopad: bool = attr.ib(default=False)
-    rotary_embedding: Optional[RotaryEmbedding] = attr.ib(default=None)
-
-    def __attrs_pre_init__(self) -> None:
-        """Pre init constructor."""
-        super().__init__()
-
-    def __attrs_post_init__(self) -> None:
-        """Post init constructor."""
-        self.scale = self.dim ** -0.5
-        inner_dim = self.num_heads * self.dim_head
-
-        self.to_qkv = nn.Linear(self.dim, 3 * inner_dim, bias=False)
-        self.dropout = nn.Dropout(p=self.dropout_rate)
-
-        # Feedforward
-        self.fc = nn.Linear(inner_dim, self.dim)
-
-    def _to_embeddings(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, List[int]]:
-        """Convert input into query, key, and value."""
-
-        def _split_heads(t: Tensor) -> Tensor:
-            return _reshape_dim(t, -1, (-1, self.dim_head)).transpose(1, 2).contiguous()
-
-        def _merge_into_batch(t: Tensor) -> Tensor:
-            return t.reshape(-1, *t.shape[-2:])
-
-        qkv = self.to_qkv(x).chunk(3, dim=-1)
-        q, k, v = map(_split_heads, qkv)
-        shape = q.shape
-
-        q, k, v = map(_merge_into_batch, (q, k, v))
-
-        if self.rotary_embedding is not None:
-            embedding = self.rotary_embedding(q)
-            q, k = _apply_rotary_emb(q, k, embedding)
-        return q, k, v, shape
-
-    def _create_buckets(
-        self, q: Tensor, k: Tensor, v: Tensor, n: int, b: int, num_windows: int
-    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
-        b_n = torch.arange(n).type_as(q).reshape(1, num_windows, self.window_size)
-        bq, bk, bv = map(
-            lambda t: t.reshape(b, num_windows, self.window_size, -1), (q, k, v),
-        )
-
-        bk = look_around(bk, backward=self.look_back)
-        bv = look_around(bv, backward=self.look_back)
-        bq_k = look_around(b_n, backward=self.look_back)
-        return b_n, bq, bk, bv, bq_k
-
-    def _apply_masks(
-        self,
-        b: int,
-        energy: Tensor,
-        b_n: Tensor,
-        bq_k: Tensor,
-        input_mask: Tensor,
-        num_windows: int,
-    ) -> Tensor:
-        mask_value = -torch.finfo(energy.dtype).max
-
-        # Causal mask.
-        causal_mask = b_n[:, :, :, None] < bq_k[:, :, None, :]
-        energy = energy.masked_fill_(causal_mask, mask_value)
-        del causal_mask
-
-        bucket_mask = bq_k[:, :, None, :] == -1
-        energy.masked_fill_(bucket_mask, mask_value)
-        del bucket_mask
-
-        energy = apply_input_mask(
-            b,
-            energy=energy,
-            input_mask=input_mask,
-            backward=self.look_back,
-            window_size=self.window_size,
-            num_windows=num_windows,
-            mask_value=mask_value,
-            autopad=self.autopad,
-        )
-        return energy
-
-    def forward(self, x: Tensor, input_mask: Optional[Tensor] = None,) -> Tensor:
-        """Computes windowed attention."""
-        q, k, v, shape = self._to_embeddings(x)
-        d = q.shape[-1]
-
-        if self.autopad:
-            orig_t = q.shape[1]
-            q, k, v = map(
-                lambda t: pad_to_multiple(t, self.window_size, dim=-2), (q, k, v)
-            )
-
-        b, n, d = q.shape
-
-        if not n % self.window_size:
-            RuntimeError(
-                f"Sequence length {n} must be divisable with window size {self.window_size}"
-            )
-
-        num_windows = n // self.window_size
-
-        # Compute buckets
-        b_n, bq, bk, bv, bq_k = self._create_buckets(q, k, v, n, b, num_windows)
-
-        # Compute the attention.
-        energy = einsum("b h i d, b h j d -> b h i j", bq, bk) * self.scale
-        energy = self._apply_masks(b, energy, b_n, bq_k, input_mask, num_windows)
-
-        attn = F.softmax(energy, dim=-1)
-        attn = self.dropout(attn)
-
-        out = einsum("b h i j, b h j d -> b h i d", attn, bv)
-        out = out.reshape(-1, n, d)
-        if self.autopad:
-            out = out[:, :orig_t, :]
-            n = orig_t
-
-        b = x.shape[0]
-        out = out.reshape(*shape)
-        out = out.reshape(b, n, -1)
-        out = self.fc(out)
-
-        return out
-
-
-def _apply_rotary_emb(q: Tensor, k: Tensor, freqs: Tensor) -> Tuple[Tensor, Tensor]:
-    q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k))
-    return q, k
-
-
-def _reshape_dim(t: Tensor, dim: int, split_dims: Tuple[int, int]) -> Tensor:
-    shape = list(t.shape)
-    dims = len(t.shape)
-    dim = (dim + dims) % dims
-    shape[dim : dim + 1] = split_dims
-    return t.reshape(shape)
-
-
-def merge_dims(ind_from: int, ind_to: int, tensor: Tensor) -> Tensor:
-    """Merge dimensions."""
-    shape = list(tensor.shape)
-    arr_slice = slice(ind_from, ind_to + 1)
-    shape[arr_slice] = [reduce(mul, shape[arr_slice])]
-    return tensor.reshape(*shape)
-
-
-def expand_dim(t: Tensor, dim: int, k: int, unsqueeze: bool = True) -> Tensor:
-    """Expand tensors dimensions."""
-    if unsqueeze:
-        t = t.unsqueeze(dim)
-    expand_shape = [-1] * len(t.shape)
-    expand_shape[dim] = k
-    return t.expand(*expand_shape)
-
-
-def look_around(x: Tensor, backward: int, pad_value: int = -1, dim: int = 2) -> Tensor:
-    """Apply windowing."""
-    n = x.shape[1]
-    dims = (len(x.shape) - dim) * (0, 0)
-    x_pad = F.pad(x, (*dims, backward, 0), value=pad_value)
-    tensors = [x_pad[:, ind : (ind + n), ...] for ind in range(backward + 1)]
-    return torch.cat(tensors, dim=dim)
-
-
-def apply_input_mask(
-    b: int,
-    energy: Tensor,
-    input_mask: Tensor,
-    backward: int,
-    window_size: int,
-    num_windows: int,
-    mask_value: Tensor,
-    autopad: bool,
-) -> Tensor:
-    """Applies input mask to energy tensor."""
-    h = b // input_mask.shape[0]
-    if autopad:
-        input_mask = pad_to_multiple(input_mask, window_size, dim=-1, value=False)
-    input_mask = input_mask.reshape(-1, num_windows, window_size)
-    mq = mk = input_mask
-    mk = look_around(mk, pad_value=False, backward=backward)
-    mask = mq[:, :, :, None] * mk[:, :, None, :]
-    mask = merge_dims(0, 1, expand_dim(mask, 1, h))
-    energy.masked_fill_(~mask, mask_value)
-    del mask
-    return energy
-
-
-def pad_to_multiple(
-    tensor: Tensor, multiple: int, dim: int = -1, value: int = 0
-) -> Tensor:
-    seqlen = tensor.shape[dim]
-    m = seqlen / multiple
-    if m.is_integer():
-        return tensor
-    remainder = math.ceil(m) * multiple - seqlen
-    pad_offset = (0,) * (-1 - dim) * 2
-    return F.pad(tensor, (*pad_offset, 0, remainder), value=value)
-- 
cgit v1.2.3-70-g09d2