"""Implementes the attention module for the transformer.""" from typing import Optional, Tuple from einops.layers.torch import Rearrange import numpy as np import torch from torch import nn from torch import Tensor import torch.nn.functional as F from text_recognizer.networks.transformer.rotary_embedding import apply_rotary_pos_emb class Attention(nn.Module): def __init__( self, dim: int, num_heads: int, dim_head: int = 64, dropout_rate: float = 0.0, causal: bool = False, ) -> None: self.scale = dim ** -0.5 self.num_heads = num_heads self.causal = causal inner_dim = dim * dim_head # Attnetion self.qkv_fn = nn.Sequential( nn.Linear(dim, 3 * inner_dim, bias=False), Rearrange("b n (qkv h d) -> qkv b h n d", qkv=3, h=self.num_heads), ) self.dropout = nn.Dropout(dropout_rate) self.attn_fn = F.softmax # Feedforward self.proj = nn.Linear(inner_dim, dim) @staticmethod def _apply_rotary_emb( q: Tensor, k: Tensor, rotary_pos_emb: Tensor ) -> Tuple[Tensor, Tensor]: l = rotary_pos_emb.shape[-1] (ql, qr), (kl, kr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k)) ql, kl = apply_rotary_pos_emb(ql, kl, rotary_pos_emb) q = torch.cat((ql, qr), dim=-1) k = torch.cat((kl, kr), dim=-1) return q, k def _cross_attention(self) -> Tensor: pass def forward( self, x: Tensor, context: Optional[Tensor], mask: Optional[Tensor], context_mask: Optional[Tensor], rotary_pos_emb: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: q, k, v = self.qkv_fn(x) q, k = ( self._apply_rotary_emb(q, k, rotary_pos_emb) if rotary_pos_emb is not None else q, k, ) if any(x is not None for x in (mask, context_mask)): pass # Compute the attention energy = (q @ k.transpose(-2, -1)) * self.scale