Diffstat (limited to 'text_recognizer/network/transformer/attend.py')
-rw-r--r--  text_recognizer/network/transformer/attend.py  |  94
1 file changed, 94 insertions(+), 0 deletions(-)
diff --git a/text_recognizer/network/transformer/attend.py b/text_recognizer/network/transformer/attend.py
new file mode 100644
index 0000000..4e643fb
--- /dev/null
+++ b/text_recognizer/network/transformer/attend.py
@@ -0,0 +1,94 @@
+from typing import Optional
+from collections import namedtuple
+
+import torch
+from torch import Tensor, einsum, nn
+from einops import rearrange
+import torch.nn.functional as F
+
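+# The config fields map onto the keyword arguments of
+# torch.backends.cuda.sdp_kernel, so a config can be applied via cfg._asdict().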
+Config = namedtuple(
+    "FlashAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"]
+)
+
+
+class Attend(nn.Module):
+    def __init__(self, use_flash: bool, dropout: float = 0.0) -> None:
+        super().__init__()
+        self.use_flash = use_flash
+        # Dropout on the attention weights, used by the manual path only.
+        self.dropout = nn.Dropout(dropout)
+        # On CPU every SDP backend is allowed.
+        self.cpu_cfg = Config(True, True, True)
+        self.cuda_cfg = None
+        if not torch.cuda.is_available():
+            return
+
+        device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
+        if device_properties.major == 8 and device_properties.minor == 0:
+            # A100 (compute capability 8.0): use the flash kernel exclusively.
+            self.cuda_cfg = Config(True, False, False)
+        else:
+            # Other GPUs: fall back to the math and memory-efficient kernels.
+            self.cuda_cfg = Config(False, True, True)
+
+    def flash_attn(self, q: Tensor, k: Tensor, v: Tensor, causal: bool) -> Tensor:
+        """Fused scaled dot-product attention; the padding mask is not applied here."""
+        cfg = self.cuda_cfg if q.is_cuda else self.cpu_cfg
+        with torch.backends.cuda.sdp_kernel(**cfg._asdict()):
+            out = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
+        return out
+
+    def atten(
+        self,
+        q: Tensor,
+        k: Tensor,
+        v: Tensor,
+        causal: bool,
+        mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        """Manual attention with explicit padding and causal masking."""
+        b = q.shape[0]
+        # Scale by 1 / sqrt(head_dim), matching the default used by
+        # F.scaled_dot_product_attention in the flash path.
+        scale = q.shape[-1] ** -0.5
+        energy = einsum("b h i d, b h j d -> b h i j", q, k) * scale
+
+        mask_value = -torch.finfo(energy.dtype).max
+
+        if mask is not None:
+            energy = apply_input_mask(b, k, energy, mask, mask_value)
+
+        if causal:
+            energy = apply_causal_mask(energy, mask_value)
+
+        attn = F.softmax(energy, dim=-1)
+        attn = self.dropout(attn)
+        return einsum("b h i j, b h j d -> b h i d", attn, v)
+
+    def forward(
+        self,
+        q: Tensor,
+        k: Tensor,
+        v: Tensor,
+        causal: bool,
+        mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        """Dispatches to the fused or the manual attention implementation."""
+        if self.use_flash:
+            return self.flash_attn(q, k, v, causal)
+        return self.atten(q, k, v, causal, mask)
+
+
+def apply_input_mask(
+    b: int,
+    k: Tensor,
+    energy: Tensor,
+    mask: Tensor,
+    mask_value: float,
+) -> Tensor:
+    """Masks the energies of padded query positions given a (b, i) boolean padding mask."""
+    # Every key position is kept; the padding mask is broadcast over the key axis.
+    k_mask = torch.ones((b, k.shape[-2]), device=energy.device).bool()
+    q_mask = rearrange(mask, "b i -> b () i ()")
+    k_mask = rearrange(k_mask, "b j -> b () () j")
+    input_mask = q_mask * k_mask
+    return energy.masked_fill_(~input_mask, mask_value)
+
+
+def apply_causal_mask(
+    energy: Tensor,
+    mask_value: float,
+) -> Tensor:
+    """Applies a causal mask to the energy tensor."""
+    i, j, device = *energy.shape[-2:], energy.device
+    # The offset j - i + 1 aligns the queries with the last i key positions, so
+    # earlier (cached) keys remain visible when j > i.
+    causal_mask = torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1)
+    return energy.masked_fill(causal_mask, mask_value)
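
For reference, a minimal usage sketch of the Attend module introduced in this commit. The tensor layout (batch, heads, seq_len, head_dim) follows the einsum patterns in the diff, while the concrete sizes and the boolean padding mask are illustrative assumptions, not part of the commit:

import torch

from text_recognizer.network.transformer.attend import Attend

# Toy tensors shaped (batch, heads, seq_len, head_dim).
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

# Boolean padding mask over the query positions, shape (batch, seq_len).
mask = torch.ones(2, 16, dtype=torch.bool)

attend = Attend(use_flash=False)
out = attend(q, k, v, causal=True, mask=mask)
print(out.shape)  # torch.Size([2, 8, 16, 64])

With use_flash=True the module instead goes through F.scaled_dot_product_attention under torch.backends.cuda.sdp_kernel, which requires PyTorch 2.0 or later; in that path only the causal flag is honoured, not the padding mask.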