import torch


def sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32):
    """Build a fixed 2D sine/cosine positional embedding for an ``h`` x ``w`` grid.

    Each grid cell ``(row, col)`` is encoded with ``dim // 4`` frequencies;
    the per-cell feature vector is laid out as
    ``[sin(col * f), cos(col * f), sin(row * f), cos(row * f)]``.

    Args:
        h: Number of grid rows.
        w: Number of grid columns.
        dim: Size of the embedding per grid cell; must be a multiple of 4.
        temperature: Base of the geometric frequency progression. The
            frequencies run from ``1`` down to ``1 / temperature``.
        dtype: Dtype the result is converted to before being returned.

    Returns:
        A tensor of shape ``(h * w, dim)`` whose rows follow row-major
        (``indexing="ij"``) order over the grid.

    Raises:
        AssertionError: If ``dim`` is not a multiple of 4.
    """
    # Raised explicitly (rather than via ``assert``) so the check is not
    # stripped under ``python -O``; the exception type is unchanged.
    if dim % 4 != 0:
        raise AssertionError("feature dimension must be multiple of 4 for sincos emb")

    rows, cols = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")

    num_freqs = dim // 4
    # With a single frequency (dim == 4) the original denominator
    # ``num_freqs - 1`` is zero, producing 0/0 = NaN everywhere. Clamping to 1
    # yields exponent 0, i.e. a single frequency of 1.0. Larger dims are
    # unaffected.
    exponents = torch.arange(num_freqs) / max(num_freqs - 1, 1)
    freqs = 1.0 / (temperature**exponents)

    # Outer products: (h * w, 1) * (1, num_freqs) -> (h * w, num_freqs).
    row_angles = rows.flatten()[:, None] * freqs[None, :]
    col_angles = cols.flatten()[:, None] * freqs[None, :]

    pe = torch.cat(
        (col_angles.sin(), col_angles.cos(), row_angles.sin(), row_angles.cos()),
        dim=1,
    )
    return pe.to(dtype)