summaryrefslogtreecommitdiff
path: root/text_recognizer/network/transformer/embedding/sincos.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/network/transformer/embedding/sincos.py')
-rw-r--r--text_recognizer/network/transformer/embedding/sincos.py13
1 files changed, 13 insertions, 0 deletions
diff --git a/text_recognizer/network/transformer/embedding/sincos.py b/text_recognizer/network/transformer/embedding/sincos.py
new file mode 100644
index 0000000..ed6b0ab
--- /dev/null
+++ b/text_recognizer/network/transformer/embedding/sincos.py
@@ -0,0 +1,13 @@
+import torch
+
+
+def sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32):
+ y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
+ assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
+ omega = torch.arange(dim // 4) / (dim // 4 - 1)
+ omega = 1.0 / (temperature**omega)
+
+ y = y.flatten()[:, None] * omega[None, :]
+ x = x.flatten()[:, None] * omega[None, :]
+ pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
+ return pe.type(dtype)