path: root/text_recognizer/networks/quantizer/utils.py
"""Helper functions for quantization."""
from typing import Tuple

import torch
from torch import einsum, Tensor
import torch.nn.functional as F


def sample_vectors(samples: Tensor, num: int) -> Tensor:
    """Subsamples `num` vectors, with replacement if fewer than `num` are available."""
    B, device = samples.shape[0], samples.device
    if B >= num:
        # Enough vectors: sample without replacement.
        indices = torch.randperm(B, device=device)[:num]
    else:
        # Too few vectors: sample with replacement.
        indices = torch.randint(0, B, (num,), device=device)
    return samples[indices]


def norm(t: Tensor) -> Tensor:
    """L2-normalizes `t` along the last dimension."""
    return F.normalize(t, p=2, dim=-1)


def ema_inplace(moving_avg: Tensor, new: Tensor, decay: float) -> None:
    """Updates `moving_avg` in place with an exponential moving average of `new`."""
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def log(t: Tensor, eps: float = 1e-20) -> Tensor:
    """Numerically stable log: clamps the input away from zero before taking the log."""
    return torch.log(t.clamp(min=eps))


def gumbel_noise(t: Tensor) -> Tensor:
    """Samples Gumbel(0, 1) noise with the same shape as `t`."""
    noise = torch.rand_like(t)
    return -log(-log(noise))


def gumbel_sample(t: Tensor, temperature: float = 1.0, dim: int = -1) -> Tensor:
    """Samples indices from logits `t` via the Gumbel-max trick; temperature 0 is a plain argmax."""
    if temperature == 0:
        return t.argmax(dim=dim)
    return ((t / temperature) + gumbel_noise(t)).argmax(dim=dim)


def orthogonal_loss_fn(t: Tensor) -> Tensor:
    """Orthogonal regularization loss, eq. (2) in https://arxiv.org/abs/2112.00384."""
    n = t.shape[0]
    normed_codes = norm(t)
    identity = torch.eye(n, device=t.device)
    cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes)
    return ((cosine_sim - identity) ** 2).sum() / (n ** 2)
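

# --- Illustrative usage: a minimal sketch, not part of the original module. ---
# The block below exercises the helpers on random data. The shapes, the seed,
# and the `__main__` guard are assumptions chosen for demonstration only.
if __name__ == "__main__":
    torch.manual_seed(0)

    # 128 candidate vectors of dimension 16; subsample 8 as codebook entries.
    samples = torch.randn(128, 16)
    codebook = sample_vectors(samples, num=8)

    # Smooth the codebook toward a new estimate with an in-place EMA update.
    new_estimate = torch.randn_like(codebook)
    ema_inplace(codebook, new_estimate, decay=0.99)

    # Stochastically pick codebook indices from logits via the Gumbel-max trick.
    logits = torch.randn(4, 8)
    indices = gumbel_sample(logits, temperature=1.0)

    # Penalize non-orthogonal codebook entries (eq. (2) of arXiv:2112.00384).
    loss = orthogonal_loss_fn(codebook)
    print(indices.shape, float(loss))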