diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-03 00:31:45 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-03 00:31:45 +0200 |
commit | 1d5e1e57c324e5c5f3521abc7c2173f84104a080 (patch) | |
tree | 20cc25f1f01a9129907cea95eca15ac38a2f3791 /text_recognizer/networks | |
parent | aa11bd46bbb7237680ab2e513dfb429e27de2536 (diff) |
Refactor rotary embedding
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py | 14 |
1 files changed, 9 insertions, 5 deletions
diff --git a/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py b/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py index 41290b4..2f58964 100644 --- a/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py +++ b/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py @@ -1,7 +1,7 @@ """Roatary embedding. Stolen from lucidrains: - https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py + https://github.com/lucidrains/rotary-embedding-torch Explanation of roatary: https://blog.eleuther.ai/rotary-embeddings/ @@ -15,16 +15,19 @@ from torch import Tensor class RotaryEmbedding(nn.Module): + """Rotary positional embedding.""" + def __init__(self, dim: int): super().__init__() inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv_freq) def forward(self, x: Tensor, seq_dim: int = 1) -> Tensor: + """Encodes tensor x with rotary embeddings.""" t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) freqs = torch.einsum("i , j -> i j", t, self.inv_freq) emb = torch.cat((freqs, freqs), dim=-1) - return emb[None, :, :] + return rearrange(emb, "n d -> () () n d") def rotate_half(x: Tensor) -> Tensor: @@ -33,6 +36,7 @@ def rotate_half(x: Tensor) -> Tensor: return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q: Tensor, k: Tensor, freqs: Tensor) -> Tuple[Tensor, Tensor]: - q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k)) - return q, k +def apply_rotary_pos_emb(t: Tensor, freqs: Tensor) -> Tensor: + seq_len = t.shape[-2] + freqs = freqs[:, :, -seq_len:] + return (t * freqs.cos()) + (rotate_half(t) * freqs.sin()) |