Diffstat (limited to 'text_recognizer/networks/transformer/attention.py')
-rw-r--r-- | text_recognizer/networks/transformer/attention.py | 80
1 file changed, 27 insertions, 53 deletions
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index fca260d..85f513e 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -1,5 +1,5 @@
 """Implementes the attention module for the transformer."""
-from typing import Optional, Tuple
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
@@ -8,7 +8,6 @@ from torch import Tensor, einsum, nn
 
 from text_recognizer.networks.transformer.embeddings.rotary import (
     RotaryEmbedding,
-    rotate_half,
 )
@@ -22,100 +21,75 @@ class Attention(nn.Module):
         causal: bool = False,
         dim_head: int = 64,
         dropout_rate: float = 0.0,
-        rotary_embedding: Optional[RotaryEmbedding] = None,
     ) -> None:
         super().__init__()
-        self.dim = dim
+        self.scale = self.dim**-0.5
         self.num_heads = num_heads
-        self.causal = causal
         self.dim_head = dim_head
+
+        self.causal = causal
         self.dropout_rate = dropout_rate
-        self.rotary_embedding = rotary_embedding
-        self.scale = self.dim**-0.5
-        inner_dim = self.num_heads * self.dim_head
+        # Single key/value head
+        k_dim = dim_head
+        v_dim = dim_head
+
+        out_dim = self.num_heads * self.dim_head
 
-        self.to_q = nn.Linear(self.dim, inner_dim, bias=False)
-        self.to_k = nn.Linear(self.dim, inner_dim, bias=False)
-        self.to_v = nn.Linear(self.dim, inner_dim, bias=False)
+        self.to_q = nn.Linear(self.dim, out_dim, bias=False)
+        self.to_k = nn.Linear(self.dim, k_dim, bias=False)
+        self.to_v = nn.Linear(self.dim, v_dim, bias=False)
 
         self.dropout = nn.Dropout(p=self.dropout_rate)
 
         # Feedforward
-        self.fc = nn.Linear(inner_dim, self.dim)
+        self.fc = nn.Linear(out_dim, self.dim)
 
     def forward(
        self,
        x: Tensor,
        context: Optional[Tensor] = None,
-       input_mask: Optional[Tensor] = None,
-       context_mask: Optional[Tensor] = None,
+       mask: Optional[Tensor] = None,
+       rotary_embedding: Optional[RotaryEmbedding] = None,
    ) -> Tensor:
        """Computes the attention."""
-       b, n, _, device = *x.shape, x.device
+       b, device = x.shape[0], x.device
 
        q = self.to_q(x)
+       q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads)
        k = self.to_k(context) if context is not None else self.to_k(x)
        v = self.to_v(context) if context is not None else self.to_v(x)
-       q, k, v = map(
-           lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (q, k, v)
-       )
 
-       if self.rotary_embedding is not None:
-           embedding = self.rotary_embedding(q)
-           q, k, v = _apply_rotary_emb(q, k, v, embedding[None, ...])
+       if rotary_embedding is not None:
+           q, k, v = map(lambda t: rotary_embedding.rotate(t), (q, k, v))
 
-       energy = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
+       energy = einsum("b h i d, b j d -> b h i j", q, k) * self.scale
        mask_value = -torch.finfo(energy.dtype).max
-       energy = apply_input_mask(
-           b, n, k, energy, input_mask, context, context_mask, mask_value, device
-       )
+       energy = apply_input_mask(b, k, energy, mask, mask_value, device)
        if self.causal:
-           energy = apply_causal_mask(energy, input_mask, mask_value, device)
+           energy = apply_causal_mask(energy, mask, mask_value, device)
 
        attn = F.softmax(energy, dim=-1)
        attn = self.dropout(attn)
-       out = einsum("b h i j, b h j d -> b h i d", attn, v)
+       out = einsum("b h i j, b j d -> b h i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.fc(out)
        return out
 
 
-def _apply_rotary_emb(
-    q: Tensor, k: Tensor, v: Tensor, freqs: Tensor
-) -> Tuple[Tensor, Tensor, Tensor]:
-    q, k, v = map(
-        lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k, v)
-    )
-    return q, k, v
-
-
 def apply_input_mask(
     b: int,
-    n: int,
     k: Tensor,
     energy: Tensor,
-    input_mask: Optional[Tensor],
-    context: Optional[Tensor],
-    context_mask: Optional[Tensor],
+    mask: Optional[Tensor],
     mask_value: Tensor,
     device: str,
 ) -> Tensor:
     """Applies an input mask."""
-    if any(x is not None for x in (input_mask, context_mask)):
-        q_mask = (
-            input_mask
-            if input_mask is not None
-            else torch.ones((b, n), device=device).bool()
-        )
-        k_mask = q_mask if context is None else context_mask
-        k_mask = (
-            torch.ones((b, k.shape[-2]), device=device).bool()
-            if k_mask is None
-            else k_mask
-        )
-        q_mask = rearrange(q_mask, "b i -> b () i ()")
+    if mask is not None:
+        k_mask = torch.ones((b, k.shape[-2]), device=device).bool()
+        q_mask = rearrange(mask, "b i -> b () i ()")
        k_mask = rearrange(k_mask, "b j -> b () () j")
        input_mask = q_mask * k_mask
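The core of this commit is that only the queries keep a per-head dimension, while a single key/value projection of size dim_head is shared by every query head (a multi-query-style attention); this is why the einsum patterns change from "b h i d, b h j d -> b h i j" to "b h i d, b j d -> b h i j". The following is a minimal standalone sketch of that shape arithmetic only; it is not the repository's module, and the helper name, mask handling, and toy dimensions are illustrative assumptions.

# Sketch (illustrative only): single shared key/value head attention.
# Queries are (b, h, n, d); keys/values are (b, m, d), shared across heads.
import torch
import torch.nn.functional as F
from torch import einsum


def single_kv_head_attention(q, k, v, scale, mask=None):
    """q: (b, h, n, d); k, v: (b, m, d); mask: optional (b, n) bool query mask."""
    energy = einsum("b h i d, b j d -> b h i j", q, k) * scale
    if mask is not None:
        mask_value = -torch.finfo(energy.dtype).max
        # Broadcast the (b, n) query mask over heads and key positions,
        # a simplified stand-in for the patched apply_input_mask helper.
        energy = energy.masked_fill(~mask[:, None, :, None], mask_value)
    attn = F.softmax(energy, dim=-1)
    return einsum("b h i j, b j d -> b h i d", attn, v)


b, h, n, m, d = 2, 4, 7, 9, 64
q = torch.randn(b, h, n, d)
k = torch.randn(b, m, d)
v = torch.randn(b, m, d)
out = single_kv_head_attention(q, k, v, scale=d**-0.5)
print(out.shape)  # torch.Size([2, 4, 7, 64])

Sharing one key/value head shrinks the to_k/to_v projections from num_heads * dim_head outputs to dim_head outputs and shrinks any cached keys/values by the same factor, at the cost of per-head key/value diversity.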