Diffstat (limited to 'text_recognizer/network/transformer')
-rw-r--r--  text_recognizer/network/transformer/attend.py | 72
1 file changed, 37 insertions(+), 35 deletions(-)
diff --git a/text_recognizer/network/transformer/attend.py b/text_recognizer/network/transformer/attend.py
index d2c17b1..a5c23c6 100644
--- a/text_recognizer/network/transformer/attend.py
+++ b/text_recognizer/network/transformer/attend.py
@@ -26,10 +26,24 @@ class Attend(nn.Module):
         else:
             self.cuda_cfg = Config(False, True, True)
 
-    def flash_attn(self, q: Tensor, k: Tensor, v: Tensor, causal: bool) -> Tensor:
+    def flash_attn(
+        self,
+        q: Tensor,
+        k: Tensor,
+        v: Tensor,
+        mask: Optional[Tensor],
+        causal: bool,
+    ) -> Tensor:
         cfg = self.cuda_cfg if q.is_cuda else self.cpu_cfg
+        if causal:
+            i, j, device = q.shape[-2], k.shape[-2], q.device
+            causal_mask = create_causal_mask(i, j, device)
+            mask = ~causal_mask if mask is None else mask & ~causal_mask
+            causal = False
         with torch.backends.cuda.sdp_kernel(**cfg._asdict()):
-            out = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
+            out = F.scaled_dot_product_attention(
+                q, k, v, attn_mask=mask, is_causal=causal
+            )
         return out
 
     def attn(
@@ -37,23 +51,25 @@ class Attend(nn.Module):
         q: Tensor,
         k: Tensor,
         v: Tensor,
+        mask: Optional[Tensor],
         causal: bool,
-        mask: Optional[Tensor] = None,
     ) -> Tensor:
-        b = q.shape[0]
-        energy = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
+        q.shape[0]
+        weight = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
 
-        mask_value = -torch.finfo(energy.dtype).max
+        mask_value = -torch.finfo(weight.dtype).max
 
         if mask is not None:
-            energy = apply_input_mask(b, k, energy, mask, mask_value)
+            weight = weight.masked_fill(~mask, mask_value)
 
         if causal:
-            energy = apply_causal_mask(energy, mask_value)
+            i, j, device = *weight.shape[-2:], weight.device
+            causal_mask = create_causal_mask(i, j, device)
+            weight = weight.masked_fill(causal_mask, mask_value)
 
-        attn = F.softmax(energy, dim=-1)
-        attn = self.dropout(attn)
-        return einsum("b h i j, b h j d -> b h i d", attn, v)
+        weight = F.softmax(weight, dim=-1)
+        weight = self.dropout(weight)
+        return einsum("b h i j, b h j d -> b h i d", weight, v)
 
     def forward(
         self,
@@ -63,32 +79,18 @@
         causal: bool,
         mask: Optional[Tensor] = None,
     ) -> Tensor:
+        if mask is not None:
+            mask = rearrange(mask, "b j -> b 1 1 j")
         if self.use_flash:
-            return self.flash_attn(q, k, v, causal)
+            return self.flash_attn(q, k, v, mask, causal)
         else:
-            return self.attn(q, k, v, causal, mask)
-
-
-def apply_input_mask(
-    b: int,
-    k: Tensor,
-    energy: Tensor,
-    mask: Optional[Tensor],
-    mask_value: float,
-) -> Tensor:
-    """Applies an input mask."""
-    k_mask = torch.ones((b, k.shape[-2]), device=energy.device).bool()
-    q_mask = rearrange(mask, "b i -> b () i ()")
-    k_mask = rearrange(k_mask, "b j -> b () () j")
-    input_mask = q_mask * k_mask
-    return energy.masked_fill_(~input_mask, mask_value)
+            return self.attn(q, k, v, mask, causal)
 
 
-def apply_causal_mask(
-    energy: Tensor,
-    mask_value: float,
+def create_causal_mask(
+    i: int,
+    j: int,
+    device: torch.device,
 ) -> Tensor:
-    """Applies a causal mask to the energy tensor."""
-    i, j, device = *energy.shape[-2:], energy.device
-    causal_mask = torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1)
-    return energy.masked_fill(causal_mask, mask_value)
+    """Creates a causal mask."""
+    return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1)
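For reference, a minimal standalone sketch (not part of the commit) of the mask convention the refactored attend.py relies on: a boolean key-padding mask with True meaning "attend to this position" is broadcast to shape (b, 1, 1, j), combined with the causal mask from create_causal_mask, and then either passed to F.scaled_dot_product_attention as attn_mask or applied manually with masked_fill. The toy shapes, tensor names, and the d ** -0.5 scale below are assumptions for the example, not values taken from the repository.

import torch
import torch.nn.functional as F
from einops import rearrange


def create_causal_mask(i: int, j: int, device: torch.device) -> torch.Tensor:
    # True marks future positions that must be blocked (same helper as in the diff above).
    return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1)


b, h, n, d = 2, 4, 6, 8  # assumed toy sizes: batch, heads, sequence length, head dim
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

# Key-padding mask: True = real token, False = padding (here the last two positions).
pad_mask = torch.ones(b, n, dtype=torch.bool)
pad_mask[:, -2:] = False
mask = rearrange(pad_mask, "b j -> b 1 1 j")

# Fold the causal constraint into the boolean mask, as the new flash_attn does,
# since SDPA does not accept an explicit attn_mask together with is_causal=True.
full_mask = mask & ~create_causal_mask(n, n, q.device)

# Fused path: a boolean attn_mask where True means "participate in attention".
out_sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=full_mask)

# Manual path, mirroring Attend.attn: fill masked-out scores with a large negative value.
scale = d ** -0.5  # assumed to match self.scale in Attend
weight = torch.einsum("b h i d, b h j d -> b h i j", q, k) * scale
weight = weight.masked_fill(~full_mask, -torch.finfo(weight.dtype).max)
out_manual = torch.einsum("b h i j, b h j d -> b h i d", weight.softmax(dim=-1), v)

# The two paths should agree up to floating-point tolerance.
print(torch.allclose(out_sdpa, out_manual, atol=1e-5))

The rearrange call only adds broadcast axes; indexing with mask[:, None, None, :] would work equally well.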