From 5a85ba8dc81c3530f61b188d0cad3c2c82091bb9 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Thu, 28 Oct 2021 21:21:19 +0200 Subject: Update rotary embedding --- text_recognizer/networks/transformer/embeddings/rotary.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) (limited to 'text_recognizer/networks/transformer') diff --git a/text_recognizer/networks/transformer/embeddings/rotary.py b/text_recognizer/networks/transformer/embeddings/rotary.py index 2f58964..ef2b85d 100644 --- a/text_recognizer/networks/transformer/embeddings/rotary.py +++ b/text_recognizer/networks/transformer/embeddings/rotary.py @@ -6,9 +6,6 @@ Stolen from lucidrains: Explanation of roatary: https://blog.eleuther.ai/rotary-embeddings/ """ -from typing import Tuple - -from einops import rearrange import torch from torch import nn from torch import Tensor @@ -17,21 +14,21 @@ from torch import Tensor class RotaryEmbedding(nn.Module): """Rotary positional embedding.""" - def __init__(self, dim: int): + def __init__(self, dim: int) -> None: super().__init__() inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv_freq) - def forward(self, x: Tensor, seq_dim: int = 1) -> Tensor: + def forward(self, x: Tensor) -> Tensor: """Encodes tensor x with rotary embeddings.""" - t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) - freqs = torch.einsum("i , j -> i j", t, self.inv_freq) + freqs = self.inv_freqs + freqs = torch.einsum("..., f -> ... f", x.type(freqs.dtype), freqs) emb = torch.cat((freqs, freqs), dim=-1) - return rearrange(emb, "n d -> () () n d") + return emb def rotate_half(x: Tensor) -> Tensor: - x = rearrange(x, "... (j d) -> ... j d", j=2) + x = x.reshape((x.shape[0], -1, 2, x.shape[-1] // 2)) x1, x2 = x.unbind(dim=-2) return torch.cat((-x2, x1), dim=-1) -- cgit v1.2.3-70-g09d2