author     Gustaf Rydholm <gustaf.rydholm@gmail.com>   2022-06-08 08:40:46 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>   2022-06-08 08:40:46 +0200
commit     8ae1b802bb7d7c63cf758e44269e97a4c0788b65 (patch)
tree       fdabe515ace3e223c45dfbe72e3a4498abefbce3 /text_recognizer/networks/conformer
parent     d194ab9fd3eb64d715c83066a7a690c6c8834dde (diff)
Add efficient self attention
Diffstat (limited to 'text_recognizer/networks/conformer')
-rw-r--r--  text_recognizer/networks/conformer/__init__.py    1
-rw-r--r--  text_recognizer/networks/conformer/attention.py   49
2 files changed, 50 insertions, 0 deletions
diff --git a/text_recognizer/networks/conformer/__init__.py b/text_recognizer/networks/conformer/__init__.py
index 1886f85..8951481 100644
--- a/text_recognizer/networks/conformer/__init__.py
+++ b/text_recognizer/networks/conformer/__init__.py
@@ -4,3 +4,4 @@ from text_recognizer.networks.conformer.glu import GLU
 from text_recognizer.networks.conformer.conformer import Conformer
 from text_recognizer.networks.conformer.conv import ConformerConv
 from text_recognizer.networks.conformer.subsampler import Subsampler
+from text_recognizer.networks.conformer.attention import Attention
diff --git a/text_recognizer/networks/conformer/attention.py b/text_recognizer/networks/conformer/attention.py
new file mode 100644
index 0000000..e56e572
--- /dev/null
+++ b/text_recognizer/networks/conformer/attention.py
@@ -0,0 +1,49 @@
+"""Efficient self attention."""
+from einops import rearrange
+import torch
+import torch.nn.functional as F
+from torch import einsum, nn, Tensor
+
+
+class LayerNorm(nn.Module):
+    def __init__(self, dim: int) -> None:
+        super().__init__()
+        self.gamma = nn.Parameter(torch.ones(dim))
+        self.register_buffer("beta", torch.zeros(dim))
+
+    def forward(self, x: Tensor) -> Tensor:
+        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
+
+
+class SwiGLU(nn.Module):
+    def forward(self, x: Tensor) -> Tensor:
+        x, gate = x.chunk(2, dim=-1)
+        return F.silu(gate) * x
+
+
+class Attention(nn.Module):
+    def __init__(
+        self, dim: int, dim_head: int = 64, heads: int = 8, mult: int = 4
+    ) -> None:
+        super().__init__()
+        self.norm = LayerNorm(dim)
+        attn_inner_dim = heads * dim_head
+        ff_inner_dim = mult * dim
+        self.heads = heads
+        self.scale = dim_head ** -0.5
+        self.fused_dims = (attn_inner_dim, dim_head, dim_head, (2 * ff_inner_dim))
+        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
+        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
+        self.ff_out = nn.Sequential(SwiGLU(), nn.Linear(ff_inner_dim, dim, bias=False))
+
+    def forward(self, x: Tensor) -> Tensor:
+        h = self.heads
+        x = self.norm(x)
+        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)
+        q = rearrange(q, "b n (h d) -> b h n d", h=h)
+        q = q * self.scale
+        sim = einsum("b h i d, b j d -> b h i j", q, k)
+        attn = sim.softmax(dim=-1)
+        out = einsum("b h i j, b j d -> b h i d", attn, v)
+        out = rearrange(out, "b h n d -> b n (h d)")
+        return self.attn_out(out) + self.ff_out(ff)
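
Note on the new module: a single fused linear projection produces the queries, a shared single-head key and value of size dim_head (multi-query attention, which is what makes it "efficient"), and the input to a SwiGLU feed-forward branch; the attention output and the feed-forward output are computed in parallel and summed. A minimal usage sketch follows (not part of the commit; the model dimension, batch size, and sequence length below are illustrative assumptions, and the import relies on the re-export added to __init__.py above):

# Usage sketch (illustrative, not part of the commit).
import torch

from text_recognizer.networks.conformer import Attention

attn = Attention(dim=256, dim_head=64, heads=8, mult=4)
x = torch.randn(2, 128, 256)   # (batch, sequence length, model dim) -- arbitrary sizes
out = attn(x)                  # parallel attention + SwiGLU feed-forward, summed
assert out.shape == x.shape    # output keeps the input shape, so it can be used residually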