blob: 0502d49f4f13ff47d34a853d41d402d6032e1603 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
|
"""Helper functions for quantization."""
from typing import Tuple
import torch
from torch import Tensor
import torch.nn.functional as F
def sample_vectors(samples: Tensor, num: int) -> Tensor:
    """Subsamples a set of vectors.

    Draws ``num`` entries along the first dimension of ``samples``. When at
    least ``num`` entries are available they are picked without replacement
    (a truncated random permutation); otherwise they are picked with
    replacement so that exactly ``num`` entries are still returned.

    Args:
        samples: Tensor whose first dimension indexes the candidates.
        num: Number of entries to draw; must be non-negative.

    Returns:
        ``samples`` indexed along its first dimension by ``num`` random
        indices generated on ``samples.device``.

    Raises:
        ValueError: If ``num`` is negative.

    Note:
        An empty ``samples`` combined with ``num > 0`` is still rejected by
        ``torch.randint`` (empty sampling range), as in the original code.
    """
    if num < 0:
        # Without this guard a negative count is treated as a slice offset
        # from the end and silently yields ``samples.shape[0] + num`` rows.
        raise ValueError(f"num must be non-negative, got {num}")

    count, device = samples.shape[0], samples.device
    if count >= num:
        # Enough candidates: sample without replacement.
        indices = torch.randperm(count, device=device)[:num]
    else:
        # Too few candidates: sample with replacement. randint already
        # produces exactly `num` indices, so no further slicing is needed.
        indices = torch.randint(0, count, (num,), device=device)
    return samples[indices]
def norm(t: Tensor) -> Tensor:
    """Applies L2-normalization.

    Args:
        t: Tensor to normalize.

    Returns:
        The result of ``F.normalize`` on ``t`` with ``p=2`` taken over the
        last dimension (``dim=-1``).
    """
    normalized = F.normalize(t, p=2, dim=-1)
    return normalized
def ema_inplace(moving_avg: Tensor, new: Tensor, decay: float) -> None:
    """Applies exponential moving average.

    Overwrites ``moving_avg`` with ``decay * moving_avg + (1 - decay) * new``
    using in-place operations on ``moving_avg.data``.

    Args:
        moving_avg: Running average tensor; modified in place.
        new: Latest value to blend into the running average.
        decay: Weight kept from the previous average; ``new`` is weighted
            by ``1 - decay``.
    """
    buffer = moving_avg.data
    # Scale the old average first, then accumulate the weighted new value.
    buffer.mul_(decay)
    buffer.add_(new, alpha=(1 - decay))
|