Diffstat (limited to 'text_recognizer/networks/quantizer/utils.py')
-rw-r--r-- | text_recognizer/networks/quantizer/utils.py | 51 |
1 file changed, 51 insertions, 0 deletions
diff --git a/text_recognizer/networks/quantizer/utils.py b/text_recognizer/networks/quantizer/utils.py
new file mode 100644
index 0000000..ec97949
--- /dev/null
+++ b/text_recognizer/networks/quantizer/utils.py
@@ -0,0 +1,51 @@
+"""Helper functions for quantization."""
+import torch
+from torch import einsum, Tensor
+import torch.nn.functional as F
+
+
+def sample_vectors(samples: Tensor, num: int) -> Tensor:
+    """Subsamples num vectors, with replacement if fewer than num are given."""
+    B, device = samples.shape[0], samples.device
+    if B >= num:
+        indices = torch.randperm(B, device=device)[:num]
+    else:
+        indices = torch.randint(0, B, (num,), device=device)
+    return samples[indices]
+
+
+def norm(t: Tensor) -> Tensor:
+    """Applies L2-normalization along the last dimension."""
+    return F.normalize(t, p=2, dim=-1)
+
+
+def ema_inplace(moving_avg: Tensor, new: Tensor, decay: float) -> None:
+    """Updates the exponential moving average in place."""
+    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
+
+
+def log(t: Tensor, eps: float = 1e-20) -> Tensor:
+    """Logarithm with the input clamped for numerical stability."""
+    return torch.log(t.clamp(min=eps))
+
+
+def gumbel_noise(t: Tensor) -> Tensor:
+    """Samples Gumbel(0, 1) noise with the same shape as t."""
+    noise = torch.zeros_like(t).uniform_(0, 1)
+    return -log(-log(noise))
+
+
+def gumbel_sample(t: Tensor, temperature: float = 1.0, dim: int = -1) -> Tensor:
+    """Draws categorical samples from logits via the Gumbel-max trick."""
+    if temperature == 0:
+        return t.argmax(dim=dim)
+    return ((t / temperature) + gumbel_noise(t)).argmax(dim=dim)
+
+
+def orthogonal_loss_fn(t: Tensor) -> Tensor:
+    """Codebook orthogonality loss, eq. (2) of https://arxiv.org/abs/2112.00384."""
+    n = t.shape[0]
+    normed_codes = norm(t)
+    identity = torch.eye(n, device=t.device)
+    cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes)
+    return ((cosine_sim - identity) ** 2).sum() / (n ** 2)
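For context, the sketch below shows one way these helpers could compose into a single EMA-style quantization step. It is illustrative only: the tensor shapes, the codebook size, and the variable names (embeddings, codebook, logits, cluster_sizes) are assumptions, not code from this repository.

# Minimal usage sketch; shapes, names, and the step structure are assumptions.
import torch

from text_recognizer.networks.quantizer.utils import (
    ema_inplace,
    gumbel_sample,
    norm,
    orthogonal_loss_fn,
    sample_vectors,
)

embeddings = torch.randn(1024, 64)         # hypothetical encoder outputs (N, D)
codebook = sample_vectors(embeddings, 16)  # seed 16 codes by subsampling inputs

# Cosine-similarity logits between inputs and codes, then a stochastic
# (Gumbel-relaxed) nearest-code assignment per input vector.
logits = norm(embeddings) @ norm(codebook).t()    # (N, 16)
indices = gumbel_sample(logits, temperature=1.0)  # (N,)

# EMA-tracked cluster sizes, the usual ingredient of EMA codebook updates.
cluster_sizes = torch.zeros(16)
counts = torch.bincount(indices, minlength=16).float()
ema_inplace(cluster_sizes, counts, decay=0.99)

# Regularizer pushing codebook entries toward mutual orthogonality,
# eq. (2) of https://arxiv.org/abs/2112.00384.
loss = orthogonal_loss_fn(codebook)

Setting temperature=0 in gumbel_sample reduces the assignment to a deterministic argmax, the usual choice at inference time; the Gumbel noise only matters during training, where it keeps code selection stochastic.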