summaryrefslogtreecommitdiff
path: root/text_recognizer/network/transformer/embedding/l2_norm.py
blob: 0e48bca0493069513656772e16ed2bb81a50c08a (plain)
1
2
3
4
5
6
7
8
9
from einops import rearrange
import torch.nn.functional as F
from torch import Tensor


def l2_norm(t: Tensor, groups: int = 1, *, eps: float = 1e-12) -> Tensor:
    """L2-normalize the last dimension of ``t``, optionally per group.

    The last dimension is split into ``groups`` contiguous chunks of equal
    width. Each chunk is scaled to unit Euclidean norm independently, and the
    chunks are then concatenated back, so the output has the same shape as
    ``t``. With ``groups=1`` the whole last dimension is normalized as one
    vector.

    Args:
        t: Input tensor with at least one dimension. Its last dimension must
            be divisible by ``groups``.
        groups: Number of equal-width chunks to normalize independently.
        eps: Lower bound on each chunk's norm, used to avoid division by zero.
            The default matches ``F.normalize``. Raise it (e.g. to ``1e-6``)
            for float16 inputs, where ``1e-12`` underflows to zero and an
            all-zero chunk would otherwise produce NaNs.

    Returns:
        A tensor with the same shape as ``t`` in which every chunk of the last
        dimension has unit L2 norm (or is divided by ``eps`` if its norm is
        smaller).

    Raises:
        RuntimeError: If the last dimension is not divisible by ``groups``.
    """
    # (..., g * d) -> (..., g, d) so that each group becomes its own vector.
    grouped = t.unflatten(-1, (groups, -1))
    grouped = F.normalize(grouped, p=2, dim=-1, eps=eps)
    # Merge the group axis back: (..., g, d) -> (..., g * d).
    return grouped.flatten(-2)