summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/transformer')
-rw-r--r--text_recognizer/networks/transformer/embeddings/rotary.py18
1 files changed, 10 insertions, 8 deletions
diff --git a/text_recognizer/networks/transformer/embeddings/rotary.py b/text_recognizer/networks/transformer/embeddings/rotary.py
index ef2b85d..722478e 100644
--- a/text_recognizer/networks/transformer/embeddings/rotary.py
+++ b/text_recognizer/networks/transformer/embeddings/rotary.py
@@ -16,24 +16,26 @@ class RotaryEmbedding(nn.Module):
def __init__(self, dim: int) -> None:
super().__init__()
- inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
- self.register_buffer("inv_freq", inv_freq)
+ inv_freqs = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer("inv_freqs", inv_freqs)
def forward(self, x: Tensor) -> Tensor:
"""Encodes tensor x with rotary embeddings."""
- freqs = self.inv_freqs
- freqs = torch.einsum("..., f -> ... f", x.type(freqs.dtype), freqs)
+ n = x.shape[-2]
+ t = torch.arange(n, device=x.device).type_as(self.inv_freqs)
+ freqs = torch.einsum("i , j -> i j", t, self.inv_freqs)
emb = torch.cat((freqs, freqs), dim=-1)
- return emb
+ return emb[None, :, :]
def rotate_half(x: Tensor) -> Tensor:
- x = x.reshape((x.shape[0], -1, 2, x.shape[-1] // 2))
+ if len(x.shape) == 3:
+ x = x.reshape((x.shape[0], -1, 2, x.shape[-1] // 2))
+ else:
+ x = x.reshape((x.shape[0], x.shape[1], -1, 2, x.shape[-1] // 2))
x1, x2 = x.unbind(dim=-2)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(t: Tensor, freqs: Tensor) -> Tensor:
- seq_len = t.shape[-2]
- freqs = freqs[:, :, -seq_len:]
return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())