author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-05-01 23:53:50 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-05-01 23:53:50 +0200 |
commit | 58ae7154aa945cfe5a46592cc1dfb28f0a4e51b3 (patch) | |
tree | c89c1b1a4cc1a499900f2700ab09e8535e2cfe99 /text_recognizer/networks/transformer/attention.py | |
parent | 7ae1f8f9654dcea0a9a22310ac0665a5d3202f0f (diff) | |
Working on new attention module
Diffstat (limited to 'text_recognizer/networks/transformer/attention.py')
-rw-r--r-- | text_recognizer/networks/transformer/attention.py | 119 |
1 file changed, 49 insertions, 70 deletions
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index ac75d2f..e1324af 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -1,94 +1,73 @@
 """Implementes the attention module for the transformer."""
 from typing import Optional, Tuple
 
-from einops import rearrange
+from einops.layers.torch import Rearrange
 import numpy as np
 import torch
 from torch import nn
 from torch import Tensor
+import torch.nn.functional as F
 
+from text_recognizer.networks.transformer.rotary_embedding import apply_rotary_pos_emb
 
-class MultiHeadAttention(nn.Module):
-    """Implementation of multihead attention."""
 
+class Attention(nn.Module):
     def __init__(
-        self, hidden_dim: int, num_heads: int = 8, dropout_rate: float = 0.0
+        self,
+        dim: int,
+        num_heads: int,
+        dim_head: int = 64,
+        dropout_rate: float = 0.0,
+        causal: bool = False,
     ) -> None:
-        super().__init__()
-        self.hidden_dim = hidden_dim
+        self.scale = dim ** -0.5
         self.num_heads = num_heads
-        self.fc_q = nn.Linear(
-            in_features=hidden_dim, out_features=hidden_dim, bias=False
-        )
-        self.fc_k = nn.Linear(
-            in_features=hidden_dim, out_features=hidden_dim, bias=False
-        )
-        self.fc_v = nn.Linear(
-            in_features=hidden_dim, out_features=hidden_dim, bias=False
-        )
-        self.fc_out = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
-
-        self._init_weights()
+        self.causal = causal
+        inner_dim = dim * dim_head
 
-        self.dropout = nn.Dropout(p=dropout_rate)
-
-    def _init_weights(self) -> None:
-        nn.init.normal_(
-            self.fc_q.weight,
-            mean=0,
-            std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
-        )
-        nn.init.normal_(
-            self.fc_k.weight,
-            mean=0,
-            std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
+        # Attnetion
+        self.qkv_fn = nn.Sequential(
+            nn.Linear(dim, 3 * inner_dim, bias=False),
+            Rearrange("b n (qkv h d) -> qkv b h n d", qkv=3, h=self.num_heads),
         )
-        nn.init.normal_(
-            self.fc_v.weight,
-            mean=0,
-            std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)),
-        )
-        nn.init.xavier_normal_(self.fc_out.weight)
+        self.dropout = nn.Dropout(dropout_rate)
+        self.attn_fn = F.softmax
 
-    @staticmethod
-    def scaled_dot_product_attention(
-        query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
-    ) -> Tensor:
-        """Calculates the scaled dot product attention."""
+        # Feedforward
+        self.proj = nn.Linear(inner_dim, dim)
 
-        # Compute the energy.
-        energy = torch.einsum("bhlk,bhtk->bhlt", [query, key]) / np.sqrt(
-            query.shape[-1]
-        )
-
-        # If we have a mask for padding some inputs.
-        if mask is not None:
-            energy = energy.masked_fill(mask == 0, -np.inf)
-
-        # Compute the attention from the energy.
-        attention = torch.softmax(energy, dim=3)
+    @staticmethod
+    def _apply_rotary_emb(
+        q: Tensor, k: Tensor, rotary_pos_emb: Tensor
+    ) -> Tuple[Tensor, Tensor]:
+        l = rotary_pos_emb.shape[-1]
+        (ql, qr), (kl, kr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k))
+        ql, kl = apply_rotary_pos_emb(ql, kl, rotary_pos_emb)
+        q = torch.cat((ql, qr), dim=-1)
+        k = torch.cat((kl, kr), dim=-1)
+        return q, k
 
-        out = torch.einsum("bhlt,bhtv->bhlv", [attention, value])
-        out = rearrange(out, "b head l v -> b l (head v)")
-        return out, attention
+    def _cross_attention(self) -> Tensor:
+        pass
 
     def forward(
-        self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
+        self,
+        x: Tensor,
+        context: Optional[Tensor],
+        mask: Optional[Tensor],
+        context_mask: Optional[Tensor],
+        rotary_pos_emb: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
-        """Forward pass for computing the multihead attention."""
-        # Get the query, key, and value tensor.
-        query = rearrange(
-            self.fc_q(query), "b l (head k) -> b head l k", head=self.num_heads
-        )
-        key = rearrange(
-            self.fc_k(key), "b t (head k) -> b head t k", head=self.num_heads
-        )
-        value = rearrange(
-            self.fc_v(value), "b t (head v) -> b head t v", head=self.num_heads
+        q, k, v = self.qkv_fn(x)
+        q, k = (
+            self._apply_rotary_emb(q, k, rotary_pos_emb)
+            if rotary_pos_emb is not None
+            else q,
+            k,
         )
-        out, attention = self.scaled_dot_product_attention(query, key, value, mask)
+        if any(x is not None for x in (mask, context_mask)):
+            pass
 
-        out = self.fc_out(out)
-        out = self.dropout(out)
-        return out, attention
+        # Compute the attention
+        energy = (q @ k.transpose(-2, -1)) * self.scale
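The new `forward` stops right after computing the attention energy; masking, softmax, the weighted sum over values, and the output projection are not yet in this commit. Below is a minimal sketch of how those remaining steps are typically wired up for the tensor layout produced by `qkv_fn` (`q`, `k`, `v` of shape `(batch, heads, seq_len, dim_head)`). The helper name `attend`, the boolean mask convention (`True` = keep), and the causal-mask construction are illustrative assumptions, not code from the repository.

```python
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor


def attend(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    scale: float,
    causal: bool = False,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """Sketch of the steps that would follow the energy computation in the diff."""
    # Scaled dot-product energy, as in the last added line of the commit:
    # (b, h, n, d) @ (b, h, d, m) -> (b, h, n, m).
    energy = (q @ k.transpose(-2, -1)) * scale
    mask_value = torch.finfo(energy.dtype).min

    # Padding mask (assumed convention: shape (b, m), True marks valid positions).
    if mask is not None:
        energy = energy.masked_fill(~mask[:, None, None, :], mask_value)

    # Causal mask: block attention to future key positions when causal=True.
    if causal:
        n, m = energy.shape[-2:]
        future = torch.ones(n, m, device=energy.device).triu(1).bool()
        energy = energy.masked_fill(future, mask_value)

    attn = F.softmax(energy, dim=-1)

    # Weighted sum over values, then merge the heads:
    # (b, h, n, d) -> (b, n, h * d), ready for an output projection.
    out = attn @ v
    return out.transpose(1, 2).flatten(start_dim=2)
```

In the module itself the merged heads would then pass through `self.proj` (and `self.dropout`), mapping `inner_dim` back to `dim`. The `_cross_attention` stub and the unused `context`/`context_mask` arguments suggest that keys and values will eventually be able to come from a separate `context` tensor, but that path is still a `pass` in this commit.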