author     Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-10-24 00:57:24 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-10-24 00:57:24 +0200
commit     87e7b7d1a4eb35df7cb4484f379c186efd981d6b (patch)
tree       50da295637ff9119805c5a82557b89aff20cf420 /text_recognizer/networks/transformer
parent     18af813a7a736d460b902e0856b2b17a9a67703b (diff)
Refactor attention
Diffstat (limited to 'text_recognizer/networks/transformer')
-rw-r--r--  text_recognizer/networks/transformer/attention.py  104
1 file changed, 48 insertions, 56 deletions
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 4e32065..e098b63 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -46,57 +46,6 @@ class Attention(nn.Module):
         # Feedforward
         self.fc = nn.Linear(inner_dim, self.dim)
 
-    @staticmethod
-    def _apply_rotary_emb(
-        q: Tensor, k: Tensor, v: Tensor, rotary_pos_emb: Tensor
-    ) -> Tuple[Tensor, Tensor, Tensor]:
-        l = rotary_pos_emb.shape[-1]
-        (ql, qr), (kl, kr), (vl, vr) = map(
-            lambda t: (t[..., :l], t[..., l:]), (q, k, v)
-        )
-        ql, kl, vl = map(
-            lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl)
-        )
-        q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))
-        return q, k, v
-
-    @staticmethod
-    def _compute_input_mask(
-        b: int,
-        n: int,
-        k: Tensor,
-        mask: Optional[Tensor],
-        context: Optional[Tensor],
-        context_mask: Optional[Tensor],
-        device: str,
-    ) -> Optional[Tensor]:
-        if any(x is not None for x in (mask, context_mask)):
-            q_mask = (
-                mask if mask is not None else torch.ones((b, n), device=device).bool()
-            )
-            k_mask = q_mask if context is None else context_mask
-            k_mask = (
-                torch.ones((b, k.shape[-2]), device=device).bool()
-                if k_mask is None
-                else k_mask
-            )
-            q_mask = rearrange(q_mask, "b i -> b () i ()")
-            k_mask = rearrange(k_mask, "b j -> b () () j")
-            return q_mask * k_mask
-        return
-
-    @staticmethod
-    def _apply_causal_mask(
-        energy: Tensor, mask: Tensor, mask_value: Tensor, device: str
-    ) -> Tensor:
-        i, j = energy.shape[-2:]
-        r = torch.arange(i, device=device)
-        mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j")
-        mask = F.pad(mask, (j - i, 0), value=False)
-        energy.masked_fill_(mask, mask_value)
-        del mask
-        return energy
-
     def forward(
         self,
         x: Tensor,
@@ -114,14 +63,12 @@ class Attention(nn.Module):
             lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (q, k, v)
         )
         q, k, v = (
-            self._apply_rotary_emb(q, k, v, rotary_pos_emb)
+            apply_rotary_emb(q, k, v, rotary_pos_emb)
             if rotary_pos_emb is not None and context is None
             else (q, k, v,)
         )
 
-        input_mask = self._compute_input_mask(
-            b, n, k, mask, context, context_mask, device
-        )
+        input_mask = compute_input_mask(b, n, k, mask, context, context_mask, device)
 
         # Compute the attention
         energy = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
@@ -133,7 +80,7 @@ class Attention(nn.Module):
             del input_mask
 
         if self.causal:
-            energy = self._apply_causal_mask(energy, mask, mask_value, device)
+            energy = apply_causal_mask(energy, mask, mask_value, device)
 
         attn = F.softmax(energy, dim=-1)
         attn = self.dropout(attn)
@@ -141,3 +88,48 @@ class Attention(nn.Module):
         out = rearrange(out, "b h n d -> b n (h d)")
         out = self.fc(out)
         return out, attn
+
+
+def apply_rotary_emb(
+    q: Tensor, k: Tensor, v: Tensor, rotary_pos_emb: Tensor
+) -> Tuple[Tensor, Tensor, Tensor]:
+    l = rotary_pos_emb.shape[-1]
+    (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
+    ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl))
+    q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))
+    return q, k, v
+
+
+def compute_input_mask(
+    b: int,
+    n: int,
+    k: Tensor,
+    mask: Optional[Tensor],
+    context: Optional[Tensor],
+    context_mask: Optional[Tensor],
+    device: str,
+) -> Optional[Tensor]:
+    if any(x is not None for x in (mask, context_mask)):
+        q_mask = mask if mask is not None else torch.ones((b, n), device=device).bool()
+        k_mask = q_mask if context is None else context_mask
+        k_mask = (
+            torch.ones((b, k.shape[-2]), device=device).bool()
+            if k_mask is None
+            else k_mask
+        )
+        q_mask = rearrange(q_mask, "b i -> b () i ()")
+        k_mask = rearrange(k_mask, "b j -> b () () j")
+        return q_mask * k_mask
+    return
+
+
+def apply_causal_mask(
+    energy: Tensor, mask: Tensor, mask_value: Tensor, device: str
+) -> Tensor:
+    i, j = energy.shape[-2:]
+    r = torch.arange(i, device=device)
+    mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j")
+    mask = F.pad(mask, (j - i, 0), value=False)
+    energy.masked_fill_(mask, mask_value)
+    del mask
+    return energy
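
A side effect of hoisting these helpers to module level is that they can now be exercised in isolation, without constructing an Attention module. Below is a minimal sketch (not part of the commit) of how compute_input_mask and apply_causal_mask might be driven with dummy tensors; the shapes and values are illustrative only, and it assumes the package is importable as text_recognizer.

# Illustrative usage of the module-level helpers introduced by this commit.
# Dummy shapes and values are assumptions for the demo, not taken from the repo.
import torch

from text_recognizer.networks.transformer.attention import (
    apply_causal_mask,
    compute_input_mask,
)

b, heads, n, dim_head = 2, 4, 5, 16
device = "cpu"

q = torch.randn(b, heads, n, dim_head)
k = torch.randn(b, heads, n, dim_head)
energy = (q @ k.transpose(-2, -1)) * dim_head ** -0.5  # (b, heads, n, n)
mask_value = -torch.finfo(energy.dtype).max

# Padding mask: pretend the last two tokens of the second sample are padding.
mask = torch.ones(b, n).bool()
mask[1, -2:] = False
input_mask = compute_input_mask(b, n, k, mask, None, None, device)  # (b, 1, n, n)
energy = energy.masked_fill(~input_mask, mask_value)

# Causal mask: block attention to future positions, as in the causal branch
# of Attention.forward.
energy = apply_causal_mask(energy, mask, mask_value, device)

attn = energy.softmax(dim=-1)
print(attn.shape)  # torch.Size([2, 4, 5, 5])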