author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-01 00:35:21 +0100
---|---|---
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-01 00:35:21 +0100
commit | 7808b54b5bd146bb3671bee5d4540513826e96ea (patch) |
tree | ce9881d906d5207fad9ad744e1089c3e687ca0fd /text_recognizer |
parent | 86dc185ec88555c52e36eb2b24d48e7ac76c8e5c (diff) |
Fix self attention module
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/networks/transformer/attention.py | 31
1 file changed, 14 insertions, 17 deletions
```diff
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index d687056..b73fec0 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -9,7 +9,10 @@ from torch import nn
 from torch import Tensor
 import torch.nn.functional as F
 
-from text_recognizer.networks.transformer.embeddings.rotary import apply_rotary_pos_emb
+from text_recognizer.networks.transformer.embeddings.rotary import (
+    RotaryEmbedding,
+    rotate_half,
+)
 
 
 @attr.s(eq=False)
@@ -25,15 +28,15 @@ class Attention(nn.Module):
     causal: bool = attr.ib(default=False)
     dim_head: int = attr.ib(default=64)
     dropout_rate: float = attr.ib(default=0.0)
+    rotary_embedding: Optional[RotaryEmbedding] = attr.ib(default=None)
     scale: float = attr.ib(init=False)
     dropout: nn.Dropout = attr.ib(init=False)
     fc: nn.Linear = attr.ib(init=False)
-    qkv_fn: nn.Sequential = attr.ib(init=False)
 
     def __attrs_post_init__(self) -> None:
         """Post init configuration."""
         self.scale = self.dim ** -0.5
-        inner_dim = self.dim * self.dim_head
+        inner_dim = self.num_heads * self.dim_head
 
         self.query = nn.Linear(self.dim, inner_dim, bias=False)
         self.key = nn.Linear(self.dim, inner_dim, bias=False)
@@ -50,7 +53,6 @@ class Attention(nn.Module):
         context: Optional[Tensor] = None,
         mask: Optional[Tensor] = None,
         context_mask: Optional[Tensor] = None,
-        rotary_pos_emb: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor]:
         """Computes the attention."""
         b, n, _, device = *x.shape, x.device
@@ -61,11 +63,10 @@ class Attention(nn.Module):
         q, k, v = map(
             lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (q, k, v)
         )
-        q, k, v = (
-            apply_rotary_emb(q, k, v, rotary_pos_emb)
-            if rotary_pos_emb is not None and context is None
-            else (q, k, v,)
-        )
+
+        if self.rotary_embedding is not None:
+            embedding = self.rotary_embedding(q)
+            q, k, v = _apply_rotary_emb(q, k, v, embedding[None, ...])
 
         energy = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
         mask_value = -torch.finfo(energy.dtype).max
@@ -83,16 +84,12 @@ class Attention(nn.Module):
         return out, attn
 
 
-def apply_rotary_emb(
-    q: Tensor, k: Tensor, v: Tensor, rotary_pos_emb: Tensor
+def _apply_rotary_emb(
+    q: Tensor, k: Tensor, v: Tensor, freqs: Tensor
 ) -> Tuple[Tensor, Tensor, Tensor]:
-    """Applies rotary embedding."""
-    emb_len = rotary_pos_emb.shape[-1]
-    (ql, qr), (kl, kr), (vl, vr) = map(
-        lambda t: (t[..., :emb_len], t[..., emb_len:]), (q, k, v)
+    q, k, v = map(
+        lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k, v)
     )
-    ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl))
-    q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))
     return q, k, v
```
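For context, a minimal, self-contained sketch of the rotation that the new `_apply_rotary_emb` helper performs: `t * cos(freqs) + rotate_half(t) * sin(freqs)`. The `rotate_half` and `rotary_freqs` functions below are assumed stand-ins following the common RoPE recipe, not the repository's actual `text_recognizer/networks/transformer/embeddings/rotary.py` implementation, which may differ in detail.

```python
# Standalone sketch (not the repository code) of applying rotary position
# embeddings the way the patched _apply_rotary_emb does.
# rotate_half and rotary_freqs are assumed helpers following the usual RoPE recipe.
import torch
from torch import Tensor


def rotate_half(x: Tensor) -> Tensor:
    # Split the head dimension in two and rotate the pairs: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def rotary_freqs(seq_len: int, dim: int, base: float = 10000.0) -> Tensor:
    # Angle table of shape (seq_len, dim): each position gets per-pair rotation angles.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    angles = torch.einsum("i,j->ij", torch.arange(seq_len).float(), inv_freq)
    return torch.cat((angles, angles), dim=-1)


def apply_rotary(t: Tensor, freqs: Tensor) -> Tensor:
    # Same expression as the patched helper.
    return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())


if __name__ == "__main__":
    b, h, n, d = 2, 4, 16, 64           # batch, heads, sequence length, dim_head
    q, k, v = (torch.randn(b, h, n, d) for _ in range(3))
    freqs = rotary_freqs(n, d)          # (n, d); broadcasts over batch and heads
    # The patch applies the same rotation to q, k, and v.
    q, k, v = (apply_rotary(t, freqs) for t in (q, k, v))
    print(q.shape, k.shape, v.shape)    # each torch.Size([2, 4, 16, 64])
```

In the patched module the frequency tensor comes from `self.rotary_embedding(q)` and is broadcast with `embedding[None, ...]`, so the embedding is now supplied at construction time (e.g. via the new `rotary_embedding` attribute) rather than passed into `forward` as `rotary_pos_emb`; the exact `RotaryEmbedding` constructor arguments are not shown in this diff.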