diff options
Diffstat (limited to 'text_recognizer/network/transformer/attend.py')
-rw-r--r-- | text_recognizer/network/transformer/attend.py | 94 |
1 files changed, 94 insertions, 0 deletions
diff --git a/text_recognizer/network/transformer/attend.py b/text_recognizer/network/transformer/attend.py new file mode 100644 index 0000000..4e643fb --- /dev/null +++ b/text_recognizer/network/transformer/attend.py @@ -0,0 +1,94 @@ +from typing import Optional +from collections import namedtuple + +import torch +from torch import Tensor, einsum, nn +from einops import rearrange +import torch.nn.functional as F + +Config = namedtuple( + "FlashAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"] +) + + +class Attend(nn.Module): + def __init__(self, use_flash: bool) -> None: + super().__init__() + self.use_flash = use_flash + self.cpu_cfg = Config(True, True, True) + self.cuda_cfg = None + if not torch.cuda.is_available(): + return + + device_properties = torch.cuda.get_device_properties(torch.device("cuda")) + if device_properties.major == 8 and device_properties.minor == 0: + self.cuda_cfg = Config(True, False, False) + else: + self.cuda_cfg = Config(False, True, True) + + def flash_attn(self, q: Tensor, k: Tensor, v: Tensor, causal: bool) -> Tensor: + cfg = self.cuda_cfg if q.is_cuda else self.cpu_cfg + with torch.backends.cuda.sdp_kernel(**cfg._asdict()): + out = F.scaled_dot_product_attention(q, k, v, is_causal=causal) + return out + + def atten( + self, + q: Tensor, + k: Tensor, + v: Tensor, + causal: bool, + mask: Optional[Tensor] = None, + ) -> Tensor: + b = q.shape[0] + energy = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale + + mask_value = -torch.finfo(energy.dtype).max + + if mask is not None: + energy = apply_input_mask(b, k, energy, mask, mask_value) + + if causal: + energy = apply_causal_mask(energy, mask_value) + + attn = F.softmax(energy, dim=-1) + attn = self.dropout(attn) + return einsum("b h i j, b h j d -> b h i d", attn, v) + + def forward( + self, + q: Tensor, + k: Tensor, + v: Tensor, + causal: bool, + mask: Optional[Tensor] = None, + ) -> Tensor: + if self.use_flash: + return self.flash_attn(q, k, v, causal) + else: + return self.atten(q, k, v, causal, mask) + + +def apply_input_mask( + b: int, + k: Tensor, + energy: Tensor, + mask: Optional[Tensor], + mask_value: float, +) -> Tensor: + """Applies an input mask.""" + k_mask = torch.ones((b, k.shape[-2]), device=energy.device).bool() + q_mask = rearrange(mask, "b i -> b () i ()") + k_mask = rearrange(k_mask, "b j -> b () () j") + input_mask = q_mask * k_mask + return energy.masked_fill_(~input_mask, mask_value) + + +def apply_causal_mask( + energy: Tensor, + mask_value: float, +) -> Tensor: + """Applies a causal mask to the energy tensor.""" + i, j, device = *energy.shape[-2:], energy.device + causal_mask = torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1) + return energy.masked_fill(causal_mask, mask_value) |