author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2023-08-25 23:19:14 +0200
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2023-08-25 23:19:14 +0200
commit | 49ca6ade1a19f7f9c702171537fe4be0dfcda66d (patch)
tree | 20062ed1910758481f3d5fff11159706c7b990c6 /text_recognizer/networks/transformer/attention.py
parent | 0421daf6bd97596703f426ba61c401599b538eeb (diff)
Rename and add flash atten
Diffstat (limited to 'text_recognizer/networks/transformer/attention.py')
-rw-r--r-- | text_recognizer/networks/transformer/attention.py | 109
1 file changed, 0 insertions, 109 deletions
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
deleted file mode 100644
index 85f513e..0000000
--- a/text_recognizer/networks/transformer/attention.py
+++ /dev/null
@@ -1,109 +0,0 @@
-"""Implementes the attention module for the transformer."""
-from typing import Optional
-
-import torch
-import torch.nn.functional as F
-from einops import rearrange
-from torch import Tensor, einsum, nn
-
-from text_recognizer.networks.transformer.embeddings.rotary import (
-    RotaryEmbedding,
-)
-
-
-class Attention(nn.Module):
-    """Standard attention."""
-
-    def __init__(
-        self,
-        dim: int,
-        num_heads: int,
-        causal: bool = False,
-        dim_head: int = 64,
-        dropout_rate: float = 0.0,
-    ) -> None:
-        super().__init__()
-        self.dim = dim
-        self.scale = self.dim**-0.5
-        self.num_heads = num_heads
-        self.dim_head = dim_head
-
-        self.causal = causal
-        self.dropout_rate = dropout_rate
-
-        # Single key/value head
-        k_dim = dim_head
-        v_dim = dim_head
-
-        out_dim = self.num_heads * self.dim_head
-
-        self.to_q = nn.Linear(self.dim, out_dim, bias=False)
-        self.to_k = nn.Linear(self.dim, k_dim, bias=False)
-        self.to_v = nn.Linear(self.dim, v_dim, bias=False)
-
-        self.dropout = nn.Dropout(p=self.dropout_rate)
-
-        # Feedforward
-        self.fc = nn.Linear(out_dim, self.dim)
-
-    def forward(
-        self,
-        x: Tensor,
-        context: Optional[Tensor] = None,
-        mask: Optional[Tensor] = None,
-        rotary_embedding: Optional[RotaryEmbedding] = None,
-    ) -> Tensor:
-        """Computes the attention."""
-        b, device = x.shape[0], x.device
-
-        q = self.to_q(x)
-        q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads)
-        k = self.to_k(context) if context is not None else self.to_k(x)
-        v = self.to_v(context) if context is not None else self.to_v(x)
-
-        if rotary_embedding is not None:
-            q, k, v = map(lambda t: rotary_embedding.rotate(t), (q, k, v))
-
-        energy = einsum("b h i d, b j d -> b h i j", q, k) * self.scale
-        mask_value = -torch.finfo(energy.dtype).max
-        energy = apply_input_mask(b, k, energy, mask, mask_value, device)
-        if self.causal:
-            energy = apply_causal_mask(energy, mask, mask_value, device)
-
-        attn = F.softmax(energy, dim=-1)
-        attn = self.dropout(attn)
-        out = einsum("b h i j, b j d -> b h i d", attn, v)
-        out = rearrange(out, "b h n d -> b n (h d)")
-        out = self.fc(out)
-        return out
-
-
-def apply_input_mask(
-    b: int,
-    k: Tensor,
-    energy: Tensor,
-    mask: Optional[Tensor],
-    mask_value: Tensor,
-    device: str,
-) -> Tensor:
-    """Applies an input mask."""
-    if mask is not None:
-        k_mask = torch.ones((b, k.shape[-2]), device=device).bool()
-        q_mask = rearrange(mask, "b i -> b () i ()")
-        k_mask = rearrange(k_mask, "b j -> b () () j")
-        input_mask = q_mask * k_mask
-
-        energy = energy.masked_fill_(~input_mask, mask_value)
-    return energy
-
-
-def apply_causal_mask(
-    energy: Tensor, mask: Tensor, mask_value: Tensor, device: str
-) -> Tensor:
-    """Applies a causal mask to the energy tensor."""
-    i, j = energy.shape[-2:]
-    r = torch.arange(i, device=device)
-    mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j")
-    mask = F.pad(mask, (j - i, 0), value=False)
-    energy.masked_fill_(mask, mask_value)
-    return energy
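The diff only records the deleted file; the renamed replacement that adds "flash atten" is not part of this diffstat, so its contents are unknown here. The sketch below is a minimal, hypothetical reconstruction of how the deleted forward pass could be expressed with `torch.nn.functional.scaled_dot_product_attention` (SDPA), which can dispatch to a flash-attention kernel on supported hardware. The class name `FlashAttention`, the mask handling, and the omission of the rotary-embedding argument are assumptions for illustration, not the actual contents of the new file.

```python
"""Hypothetical SDPA-based sketch of the deleted attention module (not the real replacement)."""
from typing import Optional

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor, nn


class FlashAttention(nn.Module):
    """Multi-head queries with a single key/value head, as in the deleted module."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        causal: bool = False,
        dim_head: int = 64,
        dropout_rate: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.causal = causal
        self.dropout_rate = dropout_rate
        inner_dim = num_heads * dim_head

        # Same projection layout as the deleted code: multi-head queries,
        # one shared key/value head (multi-query attention).
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, dim_head, bias=False)
        self.to_v = nn.Linear(dim, dim_head, bias=False)
        self.fc = nn.Linear(inner_dim, dim)

    def forward(
        self,
        x: Tensor,
        context: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Computes attention via SDPA; scaling, softmax, and dropout happen inside the kernel."""
        ctx = x if context is None else context
        q = rearrange(self.to_q(x), "b n (h d) -> b h n d", h=self.num_heads)
        # Repeat the single key/value head across query heads so plain SDPA
        # (without grouped-query support) can be used.
        k = repeat(self.to_k(ctx), "b n d -> b h n d", h=self.num_heads)
        v = repeat(self.to_v(ctx), "b n d -> b h n d", h=self.num_heads)

        # Padding mask: True marks valid (non-padded) key positions.
        attn_mask = None
        if mask is not None:
            attn_mask = rearrange(mask, "b j -> b 1 1 j")
            if self.causal:
                # Fold the causal constraint into the boolean mask, since SDPA
                # does not accept attn_mask and is_causal together.
                i, j = q.shape[-2], k.shape[-2]
                causal_mask = torch.ones(
                    (i, j), dtype=torch.bool, device=q.device
                ).tril(diagonal=j - i)
                attn_mask = attn_mask & causal_mask

        out = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            dropout_p=self.dropout_rate if self.training else 0.0,
            is_causal=self.causal and attn_mask is None,
        )
        return self.fc(rearrange(out, "b h n d -> b n (h d)"))
```

One behavioural difference to be aware of: SDPA scales scores by `dim_head ** -0.5` by default, whereas the deleted module used `dim ** -0.5`; recent PyTorch versions accept an explicit `scale=` argument if the old scaling should be kept. Rotary embeddings, which the deleted `forward` applied to q, k, and v, are left out of this sketch.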