diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-27 22:13:54 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-27 22:13:54 +0200 |
commit | fb90a53b1235fd836dee74452f3f2a621e0f363a (patch) | |
tree | daae44aa5e7c1309a41a059594ce0c3fc92cbc26 /text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py | |
parent | 8c7a59d58e2ce6b18384c9fcdba2fd49e5450b0e (diff) |
Rename transformer embeddings
Diffstat (limited to 'text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py')
-rw-r--r-- | text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py | 42 |
1 files changed, 0 insertions, 42 deletions
```diff
diff --git a/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py b/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py
deleted file mode 100644
index 2f58964..0000000
--- a/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py
+++ /dev/null
@@ -1,42 +0,0 @@
-"""Roatary embedding.
-
-Stolen from lucidrains:
-    https://github.com/lucidrains/rotary-embedding-torch
-
-Explanation of roatary:
-    https://blog.eleuther.ai/rotary-embeddings/
-"""
-from typing import Tuple
-
-from einops import rearrange
-import torch
-from torch import nn
-from torch import Tensor
-
-
-class RotaryEmbedding(nn.Module):
-    """Rotary positional embedding."""
-
-    def __init__(self, dim: int):
-        super().__init__()
-        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
-        self.register_buffer("inv_freq", inv_freq)
-
-    def forward(self, x: Tensor, seq_dim: int = 1) -> Tensor:
-        """Encodes tensor x with rotary embeddings."""
-        t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
-        freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
-        emb = torch.cat((freqs, freqs), dim=-1)
-        return rearrange(emb, "n d -> () () n d")
-
-
-def rotate_half(x: Tensor) -> Tensor:
-    x = rearrange(x, "... (j d) -> ... j d", j=2)
-    x1, x2 = x.unbind(dim=-2)
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(t: Tensor, freqs: Tensor) -> Tensor:
-    seq_len = t.shape[-2]
-    freqs = freqs[:, :, -seq_len:]
-    return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
```