diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-05-09 22:46:09 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-05-09 22:46:09 +0200 |
commit | c9c60678673e19ad3367339eb8e7a093e5a98474 (patch) | |
tree | b787a7fbb535c2ee44f935720d75034cc24ffd30 /text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py | |
parent | a2a3133ed5da283888efbdb9924d0e3733c274c8 (diff) |
Reformatting of positional encodings and ViT working
Diffstat (limited to 'text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py')
-rw-r--r-- | text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py | 39 |
1 files changed, 39 insertions, 0 deletions
"""Rotary positional embedding.

Adapted from lucidrains:
    https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py

Explanation of rotary embeddings:
    https://blog.eleuther.ai/rotary-embeddings/

"""
from typing import Tuple

import torch
from torch import nn
from torch import Tensor


class RotaryEmbedding(nn.Module):
    """Produces the per-position rotation angles used by rotary attention.

    The module itself has no trainable parameters; it only precomputes the
    inverse-frequency schedule and, given a sequence, returns the angle
    tensor that ``apply_rotary_pos_emb`` consumes.
    """

    def __init__(self, dim: int) -> None:
        """Create the frequency schedule.

        Args:
            dim: Size of the rotated feature dimension. Assumed even —
                the angles are duplicated over two half-dimensions.
        """
        super().__init__()
        # Geometric progression of inverse frequencies, one per channel
        # pair — the same schedule as the classic sinusoidal encoding.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        # Buffer, not a parameter: follows the module across devices and
        # is stored in state_dict, but is never updated by the optimizer.
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x: Tensor, seq_dim: int = 1) -> Tensor:
        """Return rotation angles of shape ``(1, seq_len, dim)``.

        Only ``x.shape[seq_dim]`` and ``x.device`` are read; the values of
        ``x`` are ignored.
        """
        t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
        # Outer product: angle[i, j] = position_i * inv_freq_j.
        freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
        # Duplicate the angles so they cover the full feature dimension.
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb[None, :, :]


def rotate_half(x: Tensor) -> Tensor:
    """Rotate pairs of half-dimensions: ``(x1, x2) -> (-x2, x1)``.

    The last dimension is split into two contiguous halves; assumed even.
    """
    # Native-torch equivalent of einops rearrange "... (j d) -> ... j d"
    # with j=2: view the last dim as (2, d), then take the two halves.
    x1, x2 = x.reshape(*x.shape[:-1], 2, -1).unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q: Tensor, k: Tensor, freqs: Tensor) -> Tuple[Tensor, Tensor]:
    """Apply the rotation angles ``freqs`` to query and key tensors.

    Implements q' = q*cos(theta) + rotate_half(q)*sin(theta) (same for k),
    i.e. a 2D rotation of each channel pair by its position-dependent angle.
    ``freqs`` must broadcast against ``q`` and ``k``.
    """
    q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k))
    return q, k