diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-01 00:35:03 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-01 00:35:03 +0100 |
commit | 86dc185ec88555c52e36eb2b24d48e7ac76c8e5c (patch) | |
tree | c3cb5106b6b49d722d13044c3854a1c7ed99a1f5 /text_recognizer/networks/transformer/embeddings | |
parent | 81f37d413e94c8120e748d8d6447c01967bffccc (diff) |
Update rotary embedding
Diffstat (limited to 'text_recognizer/networks/transformer/embeddings')
-rw-r--r-- | text_recognizer/networks/transformer/embeddings/rotary.py | 18 |
1 files changed, 10 insertions, 8 deletions
diff --git a/text_recognizer/networks/transformer/embeddings/rotary.py b/text_recognizer/networks/transformer/embeddings/rotary.py
index ef2b85d..722478e 100644
--- a/text_recognizer/networks/transformer/embeddings/rotary.py
+++ b/text_recognizer/networks/transformer/embeddings/rotary.py
@@ -16,24 +16,26 @@ class RotaryEmbedding(nn.Module):
 
     def __init__(self, dim: int) -> None:
         super().__init__()
-        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
-        self.register_buffer("inv_freq", inv_freq)
+        inv_freqs = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer("inv_freqs", inv_freqs)
 
     def forward(self, x: Tensor) -> Tensor:
         """Encodes tensor x with rotary embeddings."""
-        freqs = self.inv_freqs
-        freqs = torch.einsum("..., f -> ... f", x.type(freqs.dtype), freqs)
+        n = x.shape[-2]
+        t = torch.arange(n, device=x.device).type_as(self.inv_freqs)
+        freqs = torch.einsum("i , j -> i j", t, self.inv_freqs)
         emb = torch.cat((freqs, freqs), dim=-1)
-        return emb
+        return emb[None, :, :]
 
 
 def rotate_half(x: Tensor) -> Tensor:
-    x = x.reshape((x.shape[0], -1, 2, x.shape[-1] // 2))
+    if len(x.shape) == 3:
+        x = x.reshape((x.shape[0], -1, 2, x.shape[-1] // 2))
+    else:
+        x = x.reshape((x.shape[0], x.shape[1], -1, 2, x.shape[-1] // 2))
     x1, x2 = x.unbind(dim=-2)
     return torch.cat((-x2, x1), dim=-1)
 
 
 def apply_rotary_pos_emb(t: Tensor, freqs: Tensor) -> Tensor:
-    seq_len = t.shape[-2]
-    freqs = freqs[:, :, -seq_len:]
     return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())