summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/transformer')
-rw-r--r--text_recognizer/networks/transformer/positional_encoding.py6
-rw-r--r--text_recognizer/networks/transformer/rotary_embedding.py39
2 files changed, 44 insertions, 1 deletions
diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py
index 5874e97..c50afc3 100644
--- a/text_recognizer/networks/transformer/positional_encoding.py
+++ b/text_recognizer/networks/transformer/positional_encoding.py
@@ -33,7 +33,10 @@ class PositionalEncoding(nn.Module):
def forward(self, x: Tensor) -> Tensor:
"""Encodes the tensor with a postional embedding."""
- x = x + self.pe[:, : x.shape[1]]
+ # [T, B, D]
+ if x.shape[2] != self.pe.shape[2]:
+ raise ValueError(f"x shape does not match pe in the 3rd dim.")
+ x = x + self.pe[: x.shape[0]]
return self.dropout(x)
@@ -48,6 +51,7 @@ class PositionalEncoding2D(nn.Module):
pe = self.make_pe(hidden_dim, max_h, max_w)
self.register_buffer("pe", pe)
+ @staticmethod
def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor:
"""Returns 2d postional encoding."""
pe_h = PositionalEncoding.make_pe(
diff --git a/text_recognizer/networks/transformer/rotary_embedding.py b/text_recognizer/networks/transformer/rotary_embedding.py
new file mode 100644
index 0000000..5e80572
--- /dev/null
+++ b/text_recognizer/networks/transformer/rotary_embedding.py
@@ -0,0 +1,39 @@
+"""Roatary embedding.
+
+Stolen from lucidrains:
+ https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py
+
+Explanation of roatary:
+ https://blog.eleuther.ai/rotary-embeddings/
+
+"""
+from typing import Tuple
+
+from einops import rearrange
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class RotaryEmbedding(nn.Module):
+ def __init__(self, dim: int):
+ super().__init__()
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer("inv_freq", inv_freq)
+
+ def forward(self, x: Tensor, seq_dim: int = 1) -> Tensor:
+ t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
+ freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ return emb[None, :, :]
+
+
+def rotate_half(x: Tensor) -> Tensor:
+ x = rearrange(x, "... (j d) -> ... j d", j=2)
+ x1, x2 = x.unbind(dim=-2)
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q: Tensor, k: Tensor, freqs: Tensor) -> Tuple[Tensor, Tensor]:
+ q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k))
+ return q, k