"""Implements the attention module for the transformer."""
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor, einsum, nn

from text_recognizer.networks.transformer.embeddings.rotary import (
    RotaryEmbedding,
    rotate_half,
)


class Attention(nn.Module):
    """Multi-head dot-product attention with optional causal and padding masks.

    Works as self-attention when no ``context`` is passed to ``forward`` and as
    cross-attention (keys/values projected from ``context``) otherwise.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        causal: bool = False,
        dim_head: int = 64,
        dropout_rate: float = 0.0,
        rotary_embedding: Optional[RotaryEmbedding] = None,
    ) -> None:
        """Builds the projection layers.

        Args:
            dim: Size of the last axis of the input (and of the output).
            num_heads: Number of attention heads.
            causal: If True, a query position may not attend to later key
                positions.
            dim_head: Width of each head; the projections have width
                ``num_heads * dim_head``.
            dropout_rate: Dropout probability applied to the attention weights.
            rotary_embedding: Optional module called with the per-head query
                tensor; its output is passed as ``freqs`` to
                ``_apply_rotary_emb``.
        """
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.causal = causal
        self.dim_head = dim_head
        self.dropout_rate = dropout_rate
        self.rotary_embedding = rotary_embedding

        # NOTE(review): the logits are scaled by the model width (dim), not by
        # the per-head width (dim_head) as in the usual 1/sqrt(d_k) scaling.
        # Left unchanged on purpose because altering it changes the output of
        # already-trained weights -- confirm whether this is intended.
        self.scale = self.dim**-0.5
        inner_dim = self.num_heads * self.dim_head

        # Bias-free projections to queries, keys and values.
        self.to_q = nn.Linear(self.dim, inner_dim, bias=False)
        self.to_k = nn.Linear(self.dim, inner_dim, bias=False)
        self.to_v = nn.Linear(self.dim, inner_dim, bias=False)

        self.dropout = nn.Dropout(p=self.dropout_rate)

        # Output projection from the concatenated heads back to ``dim``.
        self.fc = nn.Linear(inner_dim, self.dim)

    def forward(
        self,
        x: Tensor,
        context: Optional[Tensor] = None,
        input_mask: Optional[Tensor] = None,
        context_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Computes the attention.

        Args:
            x: Query source, a 3-D tensor of shape (batch, n, dim).
            context: Optional key/value source with last axis of size ``dim``;
                when None, keys and values are projected from ``x``.
            input_mask: Optional (batch, n) mask over query positions, truthy
                where a position is valid.
            context_mask: Optional (batch, j) mask over key positions, truthy
                where a position is valid. Only consulted when ``context`` is
                given.

        Returns:
            Tensor of shape (batch, n, dim).
        """
        b, n, _ = x.shape
        device = x.device

        kv_source = x if context is None else context
        q = self.to_q(x)
        k = self.to_k(kv_source)
        v = self.to_v(kv_source)

        # Split the projection width into heads: (b, n, h*d) -> (b, h, n, d).
        q, k, v = (
            rearrange(t, "b n (h d) -> b h n d", h=self.num_heads)
            for t in (q, k, v)
        )

        if self.rotary_embedding is not None:
            embedding = self.rotary_embedding(q)
            q, k, v = _apply_rotary_emb(q, k, v, embedding[None, ...])

        energy = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale

        # Most negative finite value of the dtype, so masked logits get
        # (numerically) zero weight after the softmax.
        mask_value = -torch.finfo(energy.dtype).max
        energy = apply_input_mask(
            b, n, k, energy, input_mask, context, context_mask, mask_value, device
        )
        if self.causal:
            energy = apply_causal_mask(energy, input_mask, mask_value, device)

        attn = F.softmax(energy, dim=-1)
        attn = self.dropout(attn)

        out = einsum("b h i j, b h j d -> b h i d", attn, v)
        # Merge the heads back: (b, h, n, d) -> (b, n, h*d).
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.fc(out)


def _apply_rotary_emb(
    q: Tensor, k: Tensor, v: Tensor, freqs: Tensor
) -> Tuple[Tensor, Tensor, Tensor]:
    """Rotates q, k and v by ``freqs``.

    Each tensor ``t`` becomes ``t * cos(freqs) + rotate_half(t) * sin(freqs)``.
    Note that the values are rotated as well as the queries and keys.
    """
    # cos/sin only depend on freqs, so compute them once for all three tensors.
    cos, sin = freqs.cos(), freqs.sin()

    def _rotate(t: Tensor) -> Tensor:
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k), _rotate(v)


def apply_input_mask(
    b: int,
    n: int,
    k: Tensor,
    energy: Tensor,
    input_mask: Optional[Tensor],
    context: Optional[Tensor],
    context_mask: Optional[Tensor],
    mask_value: float,
    device: Union[str, torch.device],
) -> Tensor:
    """Applies an input mask.

    Fills ``energy`` in place with ``mask_value`` wherever the query position
    or the key position is masked out, and returns it. If both masks are None
    the tensor is returned untouched.

    Args:
        b: Batch size, used to build a default all-valid mask.
        n: Number of query positions, used to build a default query mask.
        k: Key tensor; ``k.shape[-2]`` is the number of key positions.
        energy: Attention logits of shape (b, h, i, j); modified in place.
        input_mask: Optional (b, i) mask over query positions, truthy = valid.
        context: Key/value source passed to the attention, or None for
            self-attention. When None, the query mask is reused for the keys
            and ``context_mask`` is ignored.
        context_mask: Optional (b, j) mask over key positions, truthy = valid.
        mask_value: Value written into masked entries.
        device: Device on which default masks are created.

    Returns:
        The same ``energy`` tensor, after masking.
    """
    if input_mask is None and context_mask is None:
        return energy

    # Cast to bool so that integer/float 0-1 masks behave like boolean ones;
    # `~` on an integer mask is a bitwise negation and masked_fill_ needs bool.
    if input_mask is not None:
        q_mask = input_mask.bool()
    else:
        q_mask = torch.ones((b, n), device=device, dtype=torch.bool)

    if context is None:
        k_mask = q_mask
    elif context_mask is not None:
        k_mask = context_mask.bool()
    else:
        k_mask = torch.ones((b, k.shape[-2]), device=device, dtype=torch.bool)

    # Broadcast to (b, 1, i, 1) and (b, 1, 1, j); an entry is kept only when
    # both its query and its key position are valid.
    q_mask = rearrange(q_mask, "b i -> b () i ()")
    k_mask = rearrange(k_mask, "b j -> b () () j")
    keep = q_mask & k_mask

    return energy.masked_fill_(~keep, mask_value)


def apply_causal_mask(
    energy: Tensor,
    mask: Optional[Tensor],
    mask_value: float,
    device: Union[str, torch.device],
) -> Tensor:
    """Applies a causal mask to the energy tensor.

    Args:
        energy: Attention logits whose last two axes are (i, j) = (queries,
            keys); modified in place.
        mask: Unused. Kept only so existing call sites keep working.
        mask_value: Value written into masked entries.
        device: Device on which the causal mask is created.

    Returns:
        The same ``energy`` tensor, after masking.
    """
    i, j = energy.shape[-2:]
    positions = torch.arange(i, device=device)

    # (1, 1, i, i) mask that is True where the query index is smaller than the
    # key index, i.e. where a query would look at a later position.
    causal = rearrange(positions, "i -> () () i ()") < rearrange(
        positions, "j -> () () () j"
    )
    # Left-pad the key axis with (j - i) unmasked columns so the mask lines up
    # with the (i, j) logits when there are more keys than queries.
    causal = F.pad(causal, (j - i, 0), value=False)

    energy.masked_fill_(causal, mask_value)
    return energy