diff options
| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-05-01 23:53:50 +0200 | 
|---|---|---|
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-05-01 23:53:50 +0200 | 
| commit | 58ae7154aa945cfe5a46592cc1dfb28f0a4e51b3 (patch) | |
| tree | c89c1b1a4cc1a499900f2700ab09e8535e2cfe99 /text_recognizer/networks/transformer/attention.py | |
| parent | 7ae1f8f9654dcea0a9a22310ac0665a5d3202f0f (diff) | |
Working on new attention module
Diffstat (limited to 'text_recognizer/networks/transformer/attention.py')
| -rw-r--r-- | text_recognizer/networks/transformer/attention.py | 119 | 
1 file changed, 49 insertions, 70 deletions
| diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index ac75d2f..e1324af 100644 --- a/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py @@ -1,94 +1,73 @@  """Implementes the attention module for the transformer."""  from typing import Optional, Tuple -from einops import rearrange +from einops.layers.torch import Rearrange  import numpy as np  import torch  from torch import nn  from torch import Tensor +import torch.nn.functional as F +from text_recognizer.networks.transformer.rotary_embedding import apply_rotary_pos_emb -class MultiHeadAttention(nn.Module): -    """Implementation of multihead attention.""" +class Attention(nn.Module):      def __init__( -        self, hidden_dim: int, num_heads: int = 8, dropout_rate: float = 0.0 +        self, +        dim: int, +        num_heads: int, +        dim_head: int = 64, +        dropout_rate: float = 0.0, +        causal: bool = False,      ) -> None: -        super().__init__() -        self.hidden_dim = hidden_dim +        self.scale = dim ** -0.5          self.num_heads = num_heads -        self.fc_q = nn.Linear( -            in_features=hidden_dim, out_features=hidden_dim, bias=False -        ) -        self.fc_k = nn.Linear( -            in_features=hidden_dim, out_features=hidden_dim, bias=False -        ) -        self.fc_v = nn.Linear( -            in_features=hidden_dim, out_features=hidden_dim, bias=False -        ) -        self.fc_out = nn.Linear(in_features=hidden_dim, out_features=hidden_dim) - -        self._init_weights() +        self.causal = causal +        inner_dim = dim * dim_head -        self.dropout = nn.Dropout(p=dropout_rate) - -    def _init_weights(self) -> None: -        nn.init.normal_( -            self.fc_q.weight, -            mean=0, -            std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)), -        ) -        nn.init.normal_( -           
 self.fc_k.weight, -            mean=0, -            std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)), +        # Attnetion +        self.qkv_fn = nn.Sequential( +            nn.Linear(dim, 3 * inner_dim, bias=False), +            Rearrange("b n (qkv h d) -> qkv b h n d", qkv=3, h=self.num_heads),          ) -        nn.init.normal_( -            self.fc_v.weight, -            mean=0, -            std=np.sqrt(self.hidden_dim + int(self.hidden_dim / self.num_heads)), -        ) -        nn.init.xavier_normal_(self.fc_out.weight) +        self.dropout = nn.Dropout(dropout_rate) +        self.attn_fn = F.softmax -    @staticmethod -    def scaled_dot_product_attention( -        query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None -    ) -> Tensor: -        """Calculates the scaled dot product attention.""" +        # Feedforward +        self.proj = nn.Linear(inner_dim, dim) -        # Compute the energy. -        energy = torch.einsum("bhlk,bhtk->bhlt", [query, key]) / np.sqrt( -            query.shape[-1] -        ) - -        # If we have a mask for padding some inputs. -        if mask is not None: -            energy = energy.masked_fill(mask == 0, -np.inf) - -        # Compute the attention from the energy. 
-        attention = torch.softmax(energy, dim=3) +    @staticmethod +    def _apply_rotary_emb( +        q: Tensor, k: Tensor, rotary_pos_emb: Tensor +    ) -> Tuple[Tensor, Tensor]: +        l = rotary_pos_emb.shape[-1] +        (ql, qr), (kl, kr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k)) +        ql, kl = apply_rotary_pos_emb(ql, kl, rotary_pos_emb) +        q = torch.cat((ql, qr), dim=-1) +        k = torch.cat((kl, kr), dim=-1) +        return q, k -        out = torch.einsum("bhlt,bhtv->bhlv", [attention, value]) -        out = rearrange(out, "b head l v -> b l (head v)") -        return out, attention +    def _cross_attention(self) -> Tensor: +        pass      def forward( -        self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None +        self, +        x: Tensor, +        context: Optional[Tensor], +        mask: Optional[Tensor], +        context_mask: Optional[Tensor], +        rotary_pos_emb: Optional[Tensor] = None,      ) -> Tuple[Tensor, Tensor]: -        """Forward pass for computing the multihead attention.""" -        # Get the query, key, and value tensor. 
-        query = rearrange( -            self.fc_q(query), "b l (head k) -> b head l k", head=self.num_heads -        ) -        key = rearrange( -            self.fc_k(key), "b t (head k) -> b head t k", head=self.num_heads -        ) -        value = rearrange( -            self.fc_v(value), "b t (head v) -> b head t v", head=self.num_heads +        q, k, v = self.qkv_fn(x) +        q, k = ( +            self._apply_rotary_emb(q, k, rotary_pos_emb) +            if rotary_pos_emb is not None +            else q, +            k,          ) -        out, attention = self.scaled_dot_product_attention(query, key, value, mask) +        if any(x is not None for x in (mask, context_mask)): +            pass -        out = self.fc_out(out) -        out = self.dropout(out) -        return out, attention +        # Compute the attention +        energy = (q @ k.transpose(-2, -1)) * self.scale |