summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/positional_encodings
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/transformer/positional_encodings')
-rw-r--r--text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py14
1 files changed, 9 insertions, 5 deletions
diff --git a/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py b/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py
index 41290b4..2f58964 100644
--- a/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py
+++ b/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py
@@ -1,7 +1,7 @@
"""Rotary embedding.
Stolen from lucidrains:
- https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py
+ https://github.com/lucidrains/rotary-embedding-torch
Explanation of rotary:
https://blog.eleuther.ai/rotary-embeddings/
@@ -15,16 +15,19 @@ from torch import Tensor
class RotaryEmbedding(nn.Module):
+ """Rotary positional embedding."""
+
def __init__(self, dim: int):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
def forward(self, x: Tensor, seq_dim: int = 1) -> Tensor:
+ """Encodes tensor x with rotary embeddings."""
t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
- return emb[None, :, :]
+ return rearrange(emb, "n d -> () () n d")
def rotate_half(x: Tensor) -> Tensor:
@@ -33,6 +36,7 @@ def rotate_half(x: Tensor) -> Tensor:
return torch.cat((-x2, x1), dim=-1)
-def apply_rotary_pos_emb(q: Tensor, k: Tensor, freqs: Tensor) -> Tuple[Tensor, Tensor]:
- q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k))
- return q, k
+def apply_rotary_pos_emb(t: Tensor, freqs: Tensor) -> Tensor:
+ seq_len = t.shape[-2]
+ freqs = freqs[:, :, -seq_len:]
+ return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())