summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer/embeddings
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-28 21:21:19 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-28 21:21:19 +0200
commit5a85ba8dc81c3530f61b188d0cad3c2c82091bb9 (patch)
tree985b3640414b2309fe8063b2dac63c6ffc4f4487 /text_recognizer/networks/transformer/embeddings
parent2203a1ba52ab2f72682fcee738844ee9ec584bda (diff)
Update rotary embedding
Diffstat (limited to 'text_recognizer/networks/transformer/embeddings')
-rw-r--r--text_recognizer/networks/transformer/embeddings/rotary.py15
1 files changed, 6 insertions, 9 deletions
diff --git a/text_recognizer/networks/transformer/embeddings/rotary.py b/text_recognizer/networks/transformer/embeddings/rotary.py
index 2f58964..ef2b85d 100644
--- a/text_recognizer/networks/transformer/embeddings/rotary.py
+++ b/text_recognizer/networks/transformer/embeddings/rotary.py
@@ -6,9 +6,6 @@ Stolen from lucidrains:
Explanation of roatary:
https://blog.eleuther.ai/rotary-embeddings/
"""
-from typing import Tuple
-
-from einops import rearrange
import torch
from torch import nn
from torch import Tensor
@@ -17,21 +14,21 @@ from torch import Tensor
class RotaryEmbedding(nn.Module):
"""Rotary positional embedding."""
- def __init__(self, dim: int):
+ def __init__(self, dim: int) -> None:
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
- def forward(self, x: Tensor, seq_dim: int = 1) -> Tensor:
+ def forward(self, x: Tensor) -> Tensor:
"""Encodes tensor x with rotary embeddings."""
- t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
- freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
+ freqs = self.inv_freqs
+ freqs = torch.einsum("..., f -> ... f", x.type(freqs.dtype), freqs)
emb = torch.cat((freqs, freqs), dim=-1)
- return rearrange(emb, "n d -> () () n d")
+ return emb
def rotate_half(x: Tensor) -> Tensor:
- x = rearrange(x, "... (j d) -> ... j d", j=2)
+ x = x.reshape((x.shape[0], -1, 2, x.shape[-1] // 2))
x1, x2 = x.unbind(dim=-2)
return torch.cat((-x2, x1), dim=-1)