summaryrefslogtreecommitdiff
path: root/text_recognizer/network/transformer/embedding/sincos.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2023-08-25 23:19:14 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2023-08-25 23:19:14 +0200
commit49ca6ade1a19f7f9c702171537fe4be0dfcda66d (patch)
tree20062ed1910758481f3d5fff11159706c7b990c6 /text_recognizer/network/transformer/embedding/sincos.py
parent0421daf6bd97596703f426ba61c401599b538eeb (diff)
Rename and add flash atten
Diffstat (limited to 'text_recognizer/network/transformer/embedding/sincos.py')
-rw-r--r--text_recognizer/network/transformer/embedding/sincos.py13
1 files changed, 13 insertions, 0 deletions
diff --git a/text_recognizer/network/transformer/embedding/sincos.py b/text_recognizer/network/transformer/embedding/sincos.py
new file mode 100644
index 0000000..ed6b0ab
--- /dev/null
+++ b/text_recognizer/network/transformer/embedding/sincos.py
@@ -0,0 +1,13 @@
+import torch
+
+
def sincos_2d(
    h: int,
    w: int,
    dim: int,
    temperature: int = 10000,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Build a fixed 2D sine/cosine positional embedding for an ``h`` x ``w`` grid.

    The feature axis is split into four equal groups of ``dim // 4`` columns,
    laid out as ``[sin(x), cos(x), sin(y), cos(y)]``, where ``x`` is the column
    index and ``y`` the row index of each grid cell. Within a group the
    frequencies fall geometrically from ``1`` down to ``1 / temperature``.

    Args:
        h: Number of grid rows.
        w: Number of grid columns.
        dim: Size of the feature axis; must be a multiple of 4.
        temperature: Base of the geometric frequency progression.
        dtype: Type the result is converted to before it is returned.

    Returns:
        A tensor of shape ``(h * w, dim)`` whose rows follow the row-major
        (``y`` outer, ``x`` inner) order of the grid cells.

    Raises:
        AssertionError: If ``dim`` is not a multiple of 4.
    """
    # Validate before allocating any tensors.
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"

    n_freqs = dim // 4
    # With a single frequency (dim == 4) the normaliser ``n_freqs - 1`` is 0 and
    # 0 / 0 would turn the whole embedding into NaN; clamping to 1 gives that
    # single band an exponent of 0, i.e. a frequency of exactly 1.
    omega = torch.arange(n_freqs) / max(n_freqs - 1, 1)
    omega = 1.0 / (temperature**omega)

    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    # Outer product: one row per grid cell, one column per frequency.
    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)