summaryrefslogtreecommitdiff
path: root/text_recognizer/network/transformer
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/network/transformer')
-rw-r--r--text_recognizer/network/transformer/embedding/rotary.py25
1 files changed, 25 insertions, 0 deletions
diff --git a/text_recognizer/network/transformer/embedding/rotary.py b/text_recognizer/network/transformer/embedding/rotary.py
new file mode 100644
index 0000000..2254f81
--- /dev/null
+++ b/text_recognizer/network/transformer/embedding/rotary.py
@@ -0,0 +1,25 @@
+import torch
+from torch import nn, einsum
+from einops import rearrange
+
+
+class RotaryEmbedding(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer("inv_freq", inv_freq)
+
+ def forward(self, max_seq_len, *, device):
+ seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
+ freqs = einsum("i , j -> i j", seq, self.inv_freq)
+ return torch.cat((freqs, freqs), dim=-1)
+
+
+def rotate_half(x):
+ x = rearrange(x, "... (j d) -> ... j d", j=2)
+ x1, x2 = x.unbind(dim=-2)
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(pos, t):
+ return (t * pos.cos()) + (rotate_half(t) * pos.sin())