From 49ca6ade1a19f7f9c702171537fe4be0dfcda66d Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Fri, 25 Aug 2023 23:19:14 +0200 Subject: Rename and add flash atten --- text_recognizer/network/transformer/embedding/sincos.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 text_recognizer/network/transformer/embedding/sincos.py (limited to 'text_recognizer/network/transformer/embedding/sincos.py') diff --git a/text_recognizer/network/transformer/embedding/sincos.py b/text_recognizer/network/transformer/embedding/sincos.py new file mode 100644 index 0000000..ed6b0ab --- /dev/null +++ b/text_recognizer/network/transformer/embedding/sincos.py @@ -0,0 +1,13 @@ +import torch + + +def sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32): + y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") + assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" + omega = torch.arange(dim // 4) / (dim // 4 - 1) + omega = 1.0 / (temperature**omega) + + y = y.flatten()[:, None] * omega[None, :] + x = x.flatten()[:, None] * omega[None, :] + pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) + return pe.type(dtype) -- cgit v1.2.3-70-g09d2