Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/transformer/attention.py | 72 |
1 file changed, 59 insertions, 13 deletions
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 8724691..623d680 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -1,9 +1,10 @@
 """Implementes the attention module for the transformer."""
 from typing import Optional, Tuple
 
+from einops import rearrange
 from einops.layers.torch import Rearrange
-import numpy as np
 import torch
+from torch import einsum
 from torch import nn
 from torch import Tensor
 import torch.nn.functional as F
@@ -34,7 +35,7 @@ class Attention(nn.Module):
         self.attn_fn = F.softmax
 
         # Feedforward
-        self.proj = nn.Linear(inner_dim, dim)
+        self.fc = nn.Linear(inner_dim, dim)
 
     @staticmethod
     def _apply_rotary_emb(
@@ -47,8 +48,42 @@ class Attention(nn.Module):
         k = torch.cat((kl, kr), dim=-1)
         return q, k
 
-    def _cross_attention(self) -> Tensor:
-        pass
+    @staticmethod
+    def _compute_input_mask(
+        b: int,
+        n: int,
+        k: Tensor,
+        mask: Optional[Tensor],
+        context: Optional[Tensor],
+        context_mask: Optional[Tensor],
+        device: str,
+    ) -> Optional[Tensor]:
+        if any(x is not None for x in (mask, context_mask)):
+            q_mask = (
+                mask if mask is not None else torch.ones((b, n), device=device).bool()
+            )
+            k_mask = q_mask if context is not None else context_mask
+            k_mask = (
+                torch.ones((b, k.shape[-2]), device=device).bool()
+                if k_mask is None
+                else k_mask
+            )
+            q_mask = rearrange(q_mask, "b i -> b () i ()")
+            k_mask = rearrange(k_mask, "b i -> b () () j")
+            return q_mask * k_mask
+        return
+
+    @staticmethod
+    def _apply_causal_mask(
+        energy: Tensor, mask: Tensor, mask_value: Tensor, device: str
+    ) -> Tensor:
+        i, j = energy.shape[-2:]
+        r = torch.arange(i, device=device)
+        mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j")
+        mask = F.pad(mask, (j - i, 0), value=False)
+        energy.masked_fill_(mask, mask_value)
+        del mask
+        return energy
 
     def forward(
         self,
@@ -67,14 +102,25 @@
                 k,
             )
 
-        input_mask = None
-        if any(x is not None for x in (mask, context_mask)):
-            q_mask = (
-                mask
-                if mask is not None
-                else lambda: torch.ones((b, n), device=device).bool()
-            )
-            pass
+        input_mask = self._compute_input_mask(
+            b, n, k, mask, context, context_mask, device
+        )
 
         # Compute the attention
-        energy = (q @ k.transpose(-2, -1)) * self.scale
+        energy = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
+        mask_value = -torch.finfo(energy.dtype).max
+
+        # Apply input mask
+        if input_mask is not None:
+            energy = energy.masked_fill_(~input_mask, mask_value)
+            del input_mask
+
+        if self.causal:
+            energy = self._apply_causal_mask(energy, mask, mask_value, device)
+
+        attn = self.attn_fn(energy, dim=-1)
+        attn = self.dropout(attn)
+        out = einsum("b h i j, b h j d -> b h i d", attn, v)
+        out = rearrange(out, "b h n d -> b n (h d)")
+        out = self.fc(out)
+        return out, attn
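
For reference, below is a minimal, self-contained sketch of the masked scaled-dot-product attention that the new _compute_input_mask and _apply_causal_mask helpers implement in this commit; the function name attend, the tensor shapes, and the standalone structure are illustrative assumptions, not part of the committed code. One detail worth flagging: the committed _compute_input_mask rearranges the key mask with the pattern "b i -> b () () j", but einops only accepts identifiers that appear on both sides of a pattern, so the sketch below uses "b j -> b () () j".

# Hypothetical, self-contained sketch of the masked attention computed in
# Attention.forward above; "attend" and the shapes are illustrative only.
# q, k, v: (batch, heads, seq, head_dim); mask: (batch, seq), True = keep.
from typing import Optional

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor, einsum


def attend(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    mask: Optional[Tensor] = None,
    causal: bool = False,
) -> Tensor:
    scale = q.shape[-1] ** -0.5
    # Scaled dot-product energies: (b, h, i, j).
    energy = einsum("b h i d, b h j d -> b h i j", q, k) * scale
    mask_value = -torch.finfo(energy.dtype).max

    if mask is not None:
        # Broadcast the padding mask over heads, queries (i) and keys (j).
        q_mask = rearrange(mask, "b i -> b () i ()")
        k_mask = rearrange(mask, "b j -> b () () j")
        energy = energy.masked_fill(~(q_mask * k_mask), mask_value)

    if causal:
        # Mask out keys that lie in the future of each query position.
        i, j = energy.shape[-2:]
        causal_mask = rearrange(
            torch.arange(i, device=energy.device), "i -> () () i ()"
        ) < rearrange(torch.arange(j, device=energy.device), "j -> () () () j")
        energy = energy.masked_fill(causal_mask, mask_value)

    attn = F.softmax(energy, dim=-1)
    out = einsum("b h i j, b h j d -> b h i d", attn, v)
    # Merge the heads back into the feature dimension: (b, n, h * d).
    return rearrange(out, "b h n d -> b n (h d)")

A possible call, assuming 8 heads of width 64:

q = torch.randn(2, 8, 10, 64)
k = torch.randn(2, 8, 10, 64)
v = torch.randn(2, 8, 10, 64)
pad_mask = torch.ones(2, 10, dtype=torch.bool)
out = attend(q, k, v, mask=pad_mask, causal=True)  # -> (2, 10, 512)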