From 53450493e0a13d835fd1d2457c49a9d60bee0e18 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 4 May 2021 23:11:44 +0200 Subject: Nyströmer implemented but not tested MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- text_recognizer/networks/transformer/norm.py | 17 +- text_recognizer/networks/transformer/nystromer.py | 0 .../networks/transformer/nystromer/__init__.py | 0 .../networks/transformer/nystromer/attention.py | 181 +++++++++++++++++++++ .../networks/transformer/nystromer/nystromer.py | 57 +++++++ 5 files changed, 252 insertions(+), 3 deletions(-) delete mode 100644 text_recognizer/networks/transformer/nystromer.py create mode 100644 text_recognizer/networks/transformer/nystromer/__init__.py create mode 100644 text_recognizer/networks/transformer/nystromer/attention.py create mode 100644 text_recognizer/networks/transformer/nystromer/nystromer.py (limited to 'text_recognizer/networks') diff --git a/text_recognizer/networks/transformer/norm.py b/text_recognizer/networks/transformer/norm.py index 9160876..58c8770 100644 --- a/text_recognizer/networks/transformer/norm.py +++ b/text_recognizer/networks/transformer/norm.py @@ -4,7 +4,7 @@ Copied from lucidrains: https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py """ -from typing import Callable, Dict +from typing import Callable, Dict, Type import torch from torch import nn @@ -31,5 +31,16 @@ class ScaleNorm(nn.Module): def forward(self, x: Tensor) -> Tensor: norm = torch.norm(x, dim=-1, keepdim=True) * self.scale - return x / norm.clamp(min=self.eps) self.g - + return x / norm.clamp(min=self.eps) * self.g + + +class PreNorm(nn.Module): + def __init__(self, dim: int, fn: Type[nn.Module]) -> None: + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + + def forward(self, x: Tensor, **kwargs: Dict) -> Tensor: + """Norm tensor.""" + x = self.norm(x) + return self.fn(x, **kwargs) diff --git a/text_recognizer/networks/transformer/nystromer.py b/text_recognizer/networks/transformer/nystromer.py deleted file mode 100644 index e69de29..0000000 diff --git a/text_recognizer/networks/transformer/nystromer/__init__.py b/text_recognizer/networks/transformer/nystromer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/text_recognizer/networks/transformer/nystromer/attention.py b/text_recognizer/networks/transformer/nystromer/attention.py new file mode 100644 index 0000000..c2871fb --- /dev/null +++ b/text_recognizer/networks/transformer/nystromer/attention.py @@ -0,0 +1,181 @@ +"""Nyströmer encoder. + +Efficient attention module that reduces the complexity of the attention module from +O(n**2) to O(n). The attention matrix is assumed low rank and thus the information +can be represented by a smaller matrix. + +Stolen from: + https://github.com/lucidrains/nystrom-attention/blob/main/nystrom_attention/nystrom_attention.py + +""" +from math import ceil +from typing import Optional, Tuple, Union + +from einops import rearrange, reduce +import torch +from torch import einsum, nn, Tensor +from torch.nn import functional as F + + +def moore_penrose_inverse(x: Tensor, iters: int = 6) -> Tensor: + """Moore-Penrose pseudoinverse.""" + x_abs = torch.abs(x) + col = x_abs.sum(dim=-1) + row = x_abs.sum(dim=-2) + z = rearrange(x, "... i j -> ... j i") / (torch.max(col) * torch.max(row)) + + I = torch.eye(x.shape[-1], device=x.device) + I = rearrange(I, "i j -> () i j") + + for _ in range(iters): + xz = x @ z + z = 0.25 * z @ (13 * I - (xz @ (15 * I - (xz @ (7 * I - xz))))) + return z + + +class NystromAttention(nn.Module): + def __init__( + self, + dim: int, + dim_head: int = 64, + num_heads: int = 8, + num_landmarks: int = 256, + inverse_iter: int = 6, + residual: bool = True, + residual_conv_kernel: int = 13, + eps: float = 1.0e-8, + dropout_rate: float = 0.0, + ): + super().__init__() + self.residual = None + self.eps = eps + self.num_heads = num_heads + inner_dim = self.num_heads * dim_head + self.num_landmarks = num_landmarks + self.inverse_iter = inverse_iter + self.scale = dim_head ** -0.5 + + self.qkv_fn = nn.Linear(dim, 3 * inner_dim, bias=False) + self.fc_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout_rate)) + + if residual: + self.residual = nn.Conv2d( + in_channels=num_heads, + out_channels=num_heads, + kernel_size=(residual_conv_kernel, 1), + padding=(residual_conv_kernel // 2, 0), + groups=num_heads, + bias=False, + ) + + @staticmethod + def _pad_sequence( + x: Tensor, mask: Optional[Tensor], n: int, m: int + ) -> Tuple[Tensor, Tensor]: + """Pad sequence.""" + padding = m - (n % m) + x = F.pad(x, (0, 0, padding, 0), value=0) + mask = F.pad(mask, (padding, 0), value=False) if mask is not None else mask + return x, mask + + def _compute_landmarks( + self, q: Tensor, k: Tensor, mask: Optional[Tensor], n: int, m: int + ) -> Tuple[Tensor, Tensor, Optional[Tensor]]: + """Compute landmarks of the attention matrix.""" + divisor = ceil(n / m) + landmark_einops_eq = "... (n l) d -> ... n d" + q_landmarks = reduce(q, landmark_einops_eq, "sum", l=divisor) + k_landmarks = reduce(k, landmark_einops_eq, "sum", l=divisor) + + mask_landmarks = None + if mask is not None: + mask_landmarks_sum = reduce(mask, "... (n l) -> ... n", "sum", l=divisor) + divisor = mask_landmarks_sum[..., None] + self.eps + mask_landmarks = mask_landmarks_sum > 0 + + q_landmarks /= divisor + k_landmarks /= divisor + + return q_landmarks, k_landmarks, mask_landmarks + + @staticmethod + def _compute_similarities( + q: Tensor, + k: Tensor, + q_landmarks: Tensor, + k_landmarks: Tensor, + mask: Optional[Tensor], + mask_landmarks: Optional[Tensor], + ) -> Tuple[Tensor, Tensor, Tensor]: + einops_eq = "... i d, ... j d -> ... i j" + sim1 = einsum(einops_eq, q, k_landmarks) + sim2 = einsum(einops_eq, q_landmarks, k_landmarks) + sim3 = einsum(einops_eq, q_landmarks, k) + + if mask is not None and mask_landmarks is not None: + mask_value = -torch.finfo(q.type).max + sim1.masked_fill_( + ~(mask[..., None] * mask_landmarks[..., None, :]), mask_value + ) + sim2.masked_fill_( + ~(mask_landmarks[..., None] * mask_landmarks[..., None, :]), mask_value + ) + sim3.masked_fill_( + ~(mask_landmarks[..., None] * mask[..., None, :]), mask_value + ) + + return sim1, sim2, sim3 + + def _nystrom_attention( + self, + q: Tensor, + k: Tensor, + v: Tensor, + mask: Optional[Tensor], + n: int, + m: int, + return_attn: bool, + ) -> Tuple[Tensor, Optional[Tensor]]: + q_landmarks, k_landmarks, mask_landmarks = self._compute_landmarks( + q, k, mask, n, m + ) + sim1, sim2, sim3 = self._compute_similarities( + q, k, q_landmarks, k_landmarks, mask, mask_landmarks + ) + + # Compute attention + attn1, attn2, attn3 = map(lambda t: t.softmax(dim=-1), (sim1, sim2, sim3)) + attn2_inv = moore_penrose_inverse(attn2, self.inverse_iter) + + out = (attn1 @ attn2_inv) @ (attn3 @ v) + + if return_attn: + return out, attn1 @ attn2_inv @ attn3 + return out, None + + def forward( + self, x: Tensor, mask: Optional[Tensor] = None, return_attn: bool = False + ) -> Union[Tensor, Tuple[Tensor, Tensor]]: + """Compute the Nystrom attention.""" + _, n, _, h, m = x.shape, self.num_heads + if n % m != 0: + x, mask = self._pad_sequence(x, mask, n, m) + + q, k, v = self.qkv_fn(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + q *= self.scale + + out, attn = self._nystrom_attention(q, k, v, mask, n, m, return_attn) + + # Add depth-wise convolutional residual of values + if self.residual is not None: + out += self.residual(out) + + out = rearrange(out, "b h n d -> b n (h d)", h=h) + out = self.fc_out(out) + out = out[:, -n:] + + if return_attn: + return out, attn + return out diff --git a/text_recognizer/networks/transformer/nystromer/nystromer.py b/text_recognizer/networks/transformer/nystromer/nystromer.py new file mode 100644 index 0000000..0283d69 --- /dev/null +++ b/text_recognizer/networks/transformer/nystromer/nystromer.py @@ -0,0 +1,57 @@ +"""Nyströmer encoder. + +Stolen from: + https://github.com/lucidrains/nystrom-attention/blob/main/nystrom_attention/nystrom_attention.py + +""" +from typing import Optional + +from torch import nn, Tensor + +from text_recognizer.networks.transformer.mlp import FeedForward +from text_recognizer.networks.transformer.norm import PreNorm +from text_recognizer.networks.transformer.nystromer.attention import NystromAttention + + +class Nystromer(nn.Module): + def __init__( + self, + *, + dim: int, + depth: int, + dim_head: int = 64, + num_heads: int = 8, + num_landmarks: int = 256, + inverse_iter: int = 6, + residual: bool = True, + residual_conv_kernel: int = 33, + dropout_rate: float = 0.0, + ): + super().__init__() + self.layers = nn.ModuleList( + [ + [ + PreNorm( + dim, + NystromAttention( + dim=dim, + dim_head=dim_head, + num_heads=num_heads, + num_landmarks=num_landmarks, + inverse_iter=inverse_iter, + residual=residual, + residual_conv_kernel=residual_conv_kernel, + dropout_rate=dropout_rate, + ), + ), + PreNorm(dim, FeedForward(dim=dim, dropout_rate=dropout_rate)), + ] + ] + for _ in range(depth) + ) + + def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: + for attn, ff in self.layers: + x = attn(x, mask=mask) + x + x = ff(x) + x + return x -- cgit v1.2.3-70-g09d2