summaryrefslogtreecommitdiff
path: root/text_recognizer/network/transformer/embedding/sincos.py
blob: ed6b0ab337e3b6c7a55b23a3d92b5bd290f922ca (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
import torch


def sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32):
    """Build a fixed 2D sine/cosine positional embedding for an ``h`` x ``w`` grid.

    Grid positions are flattened in row-major order (``y`` varies slowest), and
    each position receives ``dim`` features laid out as four equally sized
    groups: ``[sin(x), cos(x), sin(y), cos(y)]``.

    Args:
        h: Number of grid rows.
        w: Number of grid columns.
        dim: Total feature dimension; must be a multiple of 4.
        temperature: Base of the geometric frequency progression. Frequencies
            run from ``1`` down to ``1 / temperature``.
        dtype: Type the result is converted to before being returned.

    Returns:
        Tensor of shape ``(h * w, dim)``.

    Raises:
        AssertionError: If ``dim`` is not a multiple of 4.
    """
    # Kept as an assert (same message) so the failure type seen by existing
    # callers does not change.
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"

    n_freqs = dim // 4
    # With a single frequency band (dim == 4) the original denominator
    # ``n_freqs - 1`` is zero, which yields 0 / 0 = NaN and poisons the whole
    # embedding. Clamping to 1 gives that lone band an exponent of 0
    # (omega == 1) and leaves every other ``dim`` untouched.
    exponent = torch.arange(n_freqs) / max(n_freqs - 1, 1)
    omega = 1.0 / (temperature**exponent)

    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    # Outer product: (h * w, 1) positions times (1, n_freqs) frequencies.
    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]

    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)