From 88a52ea3dca450f7259efd537cb7aa5d8675178f Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 11 Sep 2023 22:11:48 +0200 Subject: Update attend --- text_recognizer/network/transformer/attend.py | 72 ++++++++++++++------------- 1 file changed, 37 insertions(+), 35 deletions(-) (limited to 'text_recognizer') diff --git a/text_recognizer/network/transformer/attend.py b/text_recognizer/network/transformer/attend.py index d2c17b1..a5c23c6 100644 --- a/text_recognizer/network/transformer/attend.py +++ b/text_recognizer/network/transformer/attend.py @@ -26,10 +26,24 @@ class Attend(nn.Module): else: self.cuda_cfg = Config(False, True, True) - def flash_attn(self, q: Tensor, k: Tensor, v: Tensor, causal: bool) -> Tensor: + def flash_attn( + self, + q: Tensor, + k: Tensor, + v: Tensor, + mask: Optional[Tensor], + causal: bool, + ) -> Tensor: cfg = self.cuda_cfg if q.is_cuda else self.cpu_cfg + if causal: + i, j, device = q.shape[-2], k.shape[-2], q.device + causal_mask = create_causal_mask(i, j, device) + mask = mask & ~causal_mask + causal = False with torch.backends.cuda.sdp_kernel(**cfg._asdict()): - out = F.scaled_dot_product_attention(q, k, v, is_causal=causal) + out = F.scaled_dot_product_attention( + q, k, v, attn_mask=mask, is_causal=causal + ) return out def attn( @@ -37,23 +51,25 @@ class Attend(nn.Module): q: Tensor, k: Tensor, v: Tensor, + mask: Optional[Tensor], causal: bool, - mask: Optional[Tensor] = None, ) -> Tensor: - b = q.shape[0] - energy = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale + q.shape[0] + weight = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale - mask_value = -torch.finfo(energy.dtype).max + mask_value = -torch.finfo(weight.dtype).max if mask is not None: - energy = apply_input_mask(b, k, energy, mask, mask_value) + weight = weight.masked_fill(~mask, mask_value) if causal: - energy = apply_causal_mask(energy, mask_value) + i, j, device = weight.shape[-2:], weight.device + causal_mask = create_causal_mask(i, j, device) + weight = weight.masked_fill(causal_mask, mask_value) - attn = F.softmax(energy, dim=-1) - attn = self.dropout(attn) - return einsum("b h i j, b h j d -> b h i d", attn, v) + weight = F.softmax(weight, dim=-1) + weight = self.dropout(weight) + return einsum("b h i j, b h j d -> b h i d", weight, v) def forward( self, @@ -63,32 +79,18 @@ class Attend(nn.Module): causal: bool, mask: Optional[Tensor] = None, ) -> Tensor: + if mask is not None: + mask = rearrange(mask, "b j -> b 1 1 j") if self.use_flash: - return self.flash_attn(q, k, v, causal) + return self.flash_attn(q, k, v, mask, causal) else: - return self.attn(q, k, v, causal, mask) - - -def apply_input_mask( - b: int, - k: Tensor, - energy: Tensor, - mask: Optional[Tensor], - mask_value: float, -) -> Tensor: - """Applies an input mask.""" - k_mask = torch.ones((b, k.shape[-2]), device=energy.device).bool() - q_mask = rearrange(mask, "b i -> b () i ()") - k_mask = rearrange(k_mask, "b j -> b () () j") - input_mask = q_mask * k_mask - return energy.masked_fill_(~input_mask, mask_value) + return self.attn(q, k, v, mask, causal) -def apply_causal_mask( - energy: Tensor, - mask_value: float, +def create_causal_mask( + i: int, + j: int, + device: torch.device, ) -> Tensor: - """Applies a causal mask to the energy tensor.""" - i, j, device = *energy.shape[-2:], energy.device - causal_mask = torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1) - return energy.masked_fill(causal_mask, mask_value) + """Applies a causal mask to the weight tensor.""" + return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1) -- cgit v1.2.3-70-g09d2