From 3a9ca4a230b59e9025216383664da8ef1780a3a0 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 11 Sep 2023 22:11:29 +0200 Subject: Add rotary embedding --- .../network/transformer/embedding/rotary.py | 25 ++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 text_recognizer/network/transformer/embedding/rotary.py diff --git a/text_recognizer/network/transformer/embedding/rotary.py b/text_recognizer/network/transformer/embedding/rotary.py new file mode 100644 index 0000000..2254f81 --- /dev/null +++ b/text_recognizer/network/transformer/embedding/rotary.py @@ -0,0 +1,25 @@ +import torch +from torch import nn, einsum +from einops import rearrange + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, max_seq_len, *, device): + seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = einsum("i , j -> i j", seq, self.inv_freq) + return torch.cat((freqs, freqs), dim=-1) + + +def rotate_half(x): + x = rearrange(x, "... (j d) -> ... j d", j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(pos, t): + return (t * pos.cos()) + (rotate_half(t) * pos.sin()) -- cgit v1.2.3-70-g09d2