author     Gustaf Rydholm <gustaf.rydholm@gmail.com>   2021-11-17 22:42:29 +0100
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>   2021-11-17 22:42:29 +0100
commit     2cb2c5b38f0711267fecfe9c5e10940f4b4f79fc (patch)
tree       e6e3dfe027a365e2ad5a14c373cad5f2aa77b3ac /text_recognizer/networks/transformer/local_attention.py
parent     91db5e23f86ec0b829aebef6eef642bcf63da53b (diff)
Remove local attention
Diffstat (limited to 'text_recognizer/networks/transformer/local_attention.py')
-rw-r--r--  text_recognizer/networks/transformer/local_attention.py  233
1 file changed, 0 insertions(+), 233 deletions(-)
diff --git a/text_recognizer/networks/transformer/local_attention.py b/text_recognizer/networks/transformer/local_attention.py
deleted file mode 100644
index a008bab..0000000
--- a/text_recognizer/networks/transformer/local_attention.py
+++ /dev/null
@@ -1,233 +0,0 @@
-"""Local attention module.
-
-Also stolen from lucidrains:
-https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py
-"""
-from functools import reduce
-import math
-from operator import mul
-from typing import List, Optional, Tuple
-
-import attr
-import torch
-from torch import einsum
-from torch import nn
-from torch import Tensor
-import torch.nn.functional as F
-
-from text_recognizer.networks.transformer.embeddings.rotary import (
-    RotaryEmbedding,
-    rotate_half,
-)
-
-
-@attr.s(eq=False)
-class LocalAttention(nn.Module):
- """Local windowed attention."""
-
- dim: int = attr.ib()
- num_heads: int = attr.ib()
- dim_head: int = attr.ib(default=64)
- window_size: int = attr.ib(default=128)
- look_back: int = attr.ib(default=1)
- dropout_rate: float = attr.ib(default=0.0)
- autopad: bool = attr.ib(default=False)
- rotary_embedding: Optional[RotaryEmbedding] = attr.ib(default=None)
-
- def __attrs_pre_init__(self) -> None:
- """Pre init constructor."""
- super().__init__()
-
- def __attrs_post_init__(self) -> None:
- """Post init constructor."""
- self.scale = self.dim ** -0.5
- inner_dim = self.num_heads * self.dim_head
-
- self.to_qkv = nn.Linear(self.dim, 3 * inner_dim, bias=False)
- self.dropout = nn.Dropout(p=self.dropout_rate)
-
- # Feedforward
- self.fc = nn.Linear(inner_dim, self.dim)
-
- def _to_embeddings(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, List[int]]:
- """Convert input into query, key, and value."""
-
- def _split_heads(t: Tensor) -> Tensor:
- return _reshape_dim(t, -1, (-1, self.dim_head)).transpose(1, 2).contiguous()
-
- def _merge_into_batch(t: Tensor) -> Tensor:
- return t.reshape(-1, *t.shape[-2:])
-
- qkv = self.to_qkv(x).chunk(3, dim=-1)
- q, k, v = map(_split_heads, qkv)
- shape = q.shape
-
- q, k, v = map(_merge_into_batch, (q, k, v))
-
- if self.rotary_embedding is not None:
- embedding = self.rotary_embedding(q)
- q, k = _apply_rotary_emb(q, k, embedding)
- return q, k, v, shape
-
- def _create_buckets(
- self, q: Tensor, k: Tensor, v: Tensor, n: int, b: int, num_windows: int
-    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
-        b_n = torch.arange(n).type_as(q).reshape(1, num_windows, self.window_size)
-        bq, bk, bv = map(
-            lambda t: t.reshape(b, num_windows, self.window_size, -1), (q, k, v),
-        )
-
-        bk = look_around(bk, backward=self.look_back)
-        bv = look_around(bv, backward=self.look_back)
-        bq_k = look_around(b_n, backward=self.look_back)
-        return b_n, bq, bk, bv, bq_k
-
-    def _apply_masks(
-        self,
-        b: int,
-        energy: Tensor,
-        b_n: Tensor,
-        bq_k: Tensor,
-        input_mask: Tensor,
-        num_windows: int,
-    ) -> Tensor:
-        mask_value = -torch.finfo(energy.dtype).max
-
-        # Causal mask.
-        causal_mask = b_n[:, :, :, None] < bq_k[:, :, None, :]
-        energy = energy.masked_fill_(causal_mask, mask_value)
-        del causal_mask
-
-        bucket_mask = bq_k[:, :, None, :] == -1
-        energy.masked_fill_(bucket_mask, mask_value)
-        del bucket_mask
-
-        energy = apply_input_mask(
-            b,
-            energy=energy,
-            input_mask=input_mask,
-            backward=self.look_back,
-            window_size=self.window_size,
-            num_windows=num_windows,
-            mask_value=mask_value,
-            autopad=self.autopad,
-        )
-        return energy
-
-    def forward(self, x: Tensor, input_mask: Optional[Tensor] = None,) -> Tensor:
-        """Computes windowed attention."""
-        q, k, v, shape = self._to_embeddings(x)
-        d = q.shape[-1]
-
-        if self.autopad:
-            orig_t = q.shape[1]
-            q, k, v = map(
-                lambda t: pad_to_multiple(t, self.window_size, dim=-2), (q, k, v)
-            )
-
-        b, n, d = q.shape
-
-        if n % self.window_size:
-            raise RuntimeError(
-                f"Sequence length {n} must be divisible by window size {self.window_size}"
-            )
-
-        num_windows = n // self.window_size
-
-        # Compute buckets
-        b_n, bq, bk, bv, bq_k = self._create_buckets(q, k, v, n, b, num_windows)
-
-        # Compute the attention.
-        energy = einsum("b h i d, b h j d -> b h i j", bq, bk) * self.scale
-        energy = self._apply_masks(b, energy, b_n, bq_k, input_mask, num_windows)
-
-        attn = F.softmax(energy, dim=-1)
-        attn = self.dropout(attn)
-
-        out = einsum("b h i j, b h j d -> b h i d", attn, bv)
-        out = out.reshape(-1, n, d)
-        if self.autopad:
-            out = out[:, :orig_t, :]
-            n = orig_t
-
-        b = x.shape[0]
-        out = out.reshape(*shape)
-        out = out.reshape(b, n, -1)
-        out = self.fc(out)
-
-        return out
-
-
-def _apply_rotary_emb(q: Tensor, k: Tensor, freqs: Tensor) -> Tuple[Tensor, Tensor]:
-    q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k))
-    return q, k
-
-
-def _reshape_dim(t: Tensor, dim: int, split_dims: Tuple[int, int]) -> Tensor:
-    shape = list(t.shape)
-    dims = len(t.shape)
-    dim = (dim + dims) % dims
-    shape[dim : dim + 1] = split_dims
-    return t.reshape(shape)
-
-
-def merge_dims(ind_from: int, ind_to: int, tensor: Tensor) -> Tensor:
-    """Merge dimensions."""
-    shape = list(tensor.shape)
-    arr_slice = slice(ind_from, ind_to + 1)
-    shape[arr_slice] = [reduce(mul, shape[arr_slice])]
-    return tensor.reshape(*shape)
-
-
-def expand_dim(t: Tensor, dim: int, k: int, unsqueeze: bool = True) -> Tensor:
-    """Expand tensors dimensions."""
-    if unsqueeze:
-        t = t.unsqueeze(dim)
-    expand_shape = [-1] * len(t.shape)
-    expand_shape[dim] = k
-    return t.expand(*expand_shape)
-
-
-def look_around(x: Tensor, backward: int, pad_value: int = -1, dim: int = 2) -> Tensor:
-    """Apply windowing."""
-    n = x.shape[1]
-    dims = (len(x.shape) - dim) * (0, 0)
-    x_pad = F.pad(x, (*dims, backward, 0), value=pad_value)
-    tensors = [x_pad[:, ind : (ind + n), ...] for ind in range(backward + 1)]
-    return torch.cat(tensors, dim=dim)
-
-
-def apply_input_mask(
-    b: int,
-    energy: Tensor,
-    input_mask: Tensor,
-    backward: int,
-    window_size: int,
-    num_windows: int,
-    mask_value: Tensor,
-    autopad: bool,
-) -> Tensor:
-    """Applies input mask to energy tensor."""
-    h = b // input_mask.shape[0]
-    if autopad:
-        input_mask = pad_to_multiple(input_mask, window_size, dim=-1, value=False)
-    input_mask = input_mask.reshape(-1, num_windows, window_size)
-    mq = mk = input_mask
-    mk = look_around(mk, pad_value=False, backward=backward)
-    mask = mq[:, :, :, None] * mk[:, :, None, :]
-    mask = merge_dims(0, 1, expand_dim(mask, 1, h))
-    energy.masked_fill_(~mask, mask_value)
-    del mask
-    return energy
-
-
-def pad_to_multiple(
-    tensor: Tensor, multiple: int, dim: int = -1, value: int = 0
-) -> Tensor:
-    seqlen = tensor.shape[dim]
-    m = seqlen / multiple
-    if m.is_integer():
-        return tensor
-    remainder = math.ceil(m) * multiple - seqlen
-    pad_offset = (0,) * (-1 - dim) * 2
-    return F.pad(tensor, (*pad_offset, 0, remainder), value=value)
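For reference, a minimal sketch of how the removed LocalAttention module was used before this commit (i.e. at the parent revision 91db5e23). The import path matches the deleted file; the batch size, sequence length, and hyperparameter values below are illustrative assumptions, not values taken from the repository.

# Sketch only: exercises the deleted LocalAttention API at the parent revision.
# All shapes and hyperparameters here are illustrative assumptions.
import torch

from text_recognizer.networks.transformer.local_attention import LocalAttention

attn = LocalAttention(
    dim=256,          # model dimension of the incoming tokens
    num_heads=4,      # number of attention heads
    dim_head=64,      # dimension per head
    window_size=128,  # tokens attend within local windows of this size
    look_back=1,      # each window may also attend to one preceding window
    autopad=True,     # pad the sequence up to a multiple of window_size
)

x = torch.randn(2, 300, 256)                 # (batch, sequence, dim)
mask = torch.ones(2, 300, dtype=torch.bool)  # True marks valid (non-padding) positions
out = attn(x, input_mask=mask)               # -> (2, 300, 256)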