diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-13 18:12:35 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-13 18:12:35 +0200 |
commit | 0901bb8172fe56caa3eba9e4bf96ae0b164f9292 (patch) | |
tree | ad1b5964af91a5982fed59715f058586cd28f60d /text_recognizer/networks/quantizer/utils.py | |
parent | 7be90f5f101d7ace7ff07180950dac4c11086ec1 (diff) |
Remove quantizer
Diffstat (limited to 'text_recognizer/networks/quantizer/utils.py')
-rw-r--r-- | text_recognizer/networks/quantizer/utils.py | 50 |
1 files changed, 0 insertions, 50 deletions
diff --git a/text_recognizer/networks/quantizer/utils.py b/text_recognizer/networks/quantizer/utils.py deleted file mode 100644 index ec97949..0000000 --- a/text_recognizer/networks/quantizer/utils.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Helper functions for quantization.""" -from typing import Tuple - -import torch -from torch import einsum, Tensor -import torch.nn.functional as F - - -def sample_vectors(samples: Tensor, num: int) -> Tensor: - """Subsamples a set of vectors.""" - B, device = samples.shape[0], samples.device - if B >= num: - indices = torch.randperm(B, device=device)[:num] - else: - indices = torch.randint(0, B, (num,), device=device)[:num] - return samples[indices] - - -def norm(t: Tensor) -> Tensor: - """Applies L2-normalization.""" - return F.normalize(t, p=2, dim=-1) - - -def ema_inplace(moving_avg: Tensor, new: Tensor, decay: float) -> None: - """Applies exponential moving average.""" - moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) - - -def log(t: Tensor, eps: float = 1e-20) -> Tensor: - return torch.log(t.clamp(min=eps)) - - -def gumbel_noise(t: Tensor) -> Tensor: - noise = torch.zeros_like(t).uniform_(0, 1) - return -log(-log(noise)) - - -def gumbel_sample(t: Tensor, temperature: float = 1.0, dim: int = -1) -> Tensor: - if temperature == 0: - return t.argmax(dim=dim) - return ((t / temperature) + gumbel_noise(t)).argmax(dim=dim) - - -def orthgonal_loss_fn(t: Tensor) -> Tensor: - # eq (2) from https://arxiv.org/abs/2112.00384 - n = t.shape[0] - normed_codes = norm(t) - identity = torch.eye(n, device=t.device) - cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes) - return ((cosine_sim - identity) ** 2).sum() / (n ** 2) |