blob: f5ec4ba35b0aeb2cc31f26ae43b25bafed1b29c3 (
plain)
1
2
3
4
5
6
7
8
9
|
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor
def l2_norm(t: Tensor, groups: int = 1, *, eps: float = 1e-12) -> Tensor:
    """L2-normalize the last dimension of ``t`` independently within each group.

    The last dimension is split into ``groups`` equal, contiguous chunks. Each
    chunk is scaled to unit Euclidean norm, and the chunks are merged back, so
    the output has the same shape as ``t``. With ``groups=1`` this is plain L2
    normalization over the last dimension.

    Args:
        t: Input tensor whose last dimension has size ``groups * d``.
        groups: Number of equal-size chunks to normalize separately. The last
            dimension of ``t`` must be divisible by it.
        eps: Lower bound on each chunk's norm, forwarded to
            ``torch.nn.functional.normalize`` so all-zero chunks do not divide
            by zero. The default matches ``F.normalize``'s own default. In
            float16, 1e-12 rounds to 0, so all-zero chunks can yield NaN; pass
            a larger value (e.g. ``1e-6``) for half-precision inputs.

    Returns:
        Tensor with the same shape as ``t``, each group having unit L2 norm
        (or zero where the group was all zeros).

    Raises:
        einops.EinopsError: raised by ``rearrange`` if the last dimension is
            not divisible by ``groups``.
    """
    # Expose a separate group axis so the norm is computed per chunk rather
    # than across the whole last dimension.
    t = rearrange(t, "... (g d) -> ... g d", g=groups)
    t = F.normalize(t, p=2, dim=-1, eps=eps)
    # Merge the group axis back to restore the input's original shape.
    return rearrange(t, "... g d -> ... (g d)")
|