summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-27 22:13:54 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-27 22:13:54 +0200
commitfb90a53b1235fd836dee74452f3f2a621e0f363a (patch)
treedaae44aa5e7c1309a41a059594ce0c3fc92cbc26 /text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py
parent8c7a59d58e2ce6b18384c9fcdba2fd49e5450b0e (diff)
Rename transformer embeddings
Diffstat (limited to 'text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py')
-rw-r--r--text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py42
1 files changed, 0 insertions, 42 deletions
diff --git a/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py b/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py
deleted file mode 100644
index 2f58964..0000000
--- a/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py
+++ /dev/null
@@ -1,42 +0,0 @@
-"""Roatary embedding.
-
-Stolen from lucidrains:
- https://github.com/lucidrains/rotary-embedding-torch
-
-Explanation of roatary:
- https://blog.eleuther.ai/rotary-embeddings/
-"""
-from typing import Tuple
-
-from einops import rearrange
-import torch
-from torch import nn
-from torch import Tensor
-
-
-class RotaryEmbedding(nn.Module):
- """Rotary positional embedding."""
-
- def __init__(self, dim: int):
- super().__init__()
- inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
- self.register_buffer("inv_freq", inv_freq)
-
- def forward(self, x: Tensor, seq_dim: int = 1) -> Tensor:
- """Encodes tensor x with rotary embeddings."""
- t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
- freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
- emb = torch.cat((freqs, freqs), dim=-1)
- return rearrange(emb, "n d -> () () n d")
-
-
-def rotate_half(x: Tensor) -> Tensor:
- x = rearrange(x, "... (j d) -> ... j d", j=2)
- x1, x2 = x.unbind(dim=-2)
- return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(t: Tensor, freqs: Tensor) -> Tensor:
- seq_len = t.shape[-2]
- freqs = freqs[:, :, -seq_len:]
- return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())