summaryrefslogtreecommitdiff
path: root/text_recognizer/network/transformer/embedding/l2_norm.py
blob: f5ec4ba35b0aeb2cc31f26ae43b25bafed1b29c3 (plain)
1
2
3
4
5
6
7
8
9
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor


def l2_norm(t: Tensor, groups: int = 1, *, eps: float = 1e-12) -> Tensor:
    """L2-normalize the last dimension of ``t``, optionally in groups.

    The last dimension is split into ``groups`` equal-sized, contiguous
    chunks, and each chunk is normalized independently with
    ``F.normalize(p=2)``. The chunks are then concatenated back, so the
    output has the same shape as ``t``. With ``groups=1`` the whole last
    dimension is normalized as a single vector.

    Args:
        t: Input tensor with at least one dimension.
        groups: Number of equal-sized chunks to split the last dimension
            into. Must be a positive integer that evenly divides
            ``t.shape[-1]``.
        eps: Passed through to ``F.normalize`` as the lower bound on the
            divisor. This avoids division by zero for all-zero chunks. The
            default matches ``F.normalize``'s own default, so existing
            calls behave as before.

    Returns:
        Tensor with the same shape as ``t``, where every group of the last
        dimension is L2-normalized.

    Raises:
        ValueError: If ``t`` is 0-dimensional, ``groups`` is not positive,
            or ``groups`` does not evenly divide the last dimension.
    """
    # Validate up front so callers get a clear error instead of an opaque
    # failure from inside the einops pattern parser.
    if t.ndim == 0:
        raise ValueError("l2_norm expects a tensor with at least one dimension")
    if groups < 1:
        raise ValueError(f"groups must be a positive integer, got {groups}")
    if t.shape[-1] % groups != 0:
        raise ValueError(
            f"groups ({groups}) must evenly divide the last dimension "
            f"({t.shape[-1]})"
        )
    # Split the last axis into (groups, group_size) so normalization is per group.
    t = rearrange(t, "... (g d) -> ... g d", g=groups)
    t = F.normalize(t, p=2, dim=-1, eps=eps)
    # Merge the groups back into a single last axis (inverse of the split above).
    return rearrange(t, "... g d -> ... (g d)")