diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-21 21:34:53 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-21 21:34:53 +0100 |
commit | b44de0e11281c723ec426f8bec8ca0897ecfe3ff (patch) | |
tree | 998841a3a681d3dedfbe8470c1b8544b4dcbe7a2 /text_recognizer/networks/quantizer/utils.py | |
parent | 3b2fb0fd977a6aff4dcf88e1a0f99faac51e05b1 (diff) |
Remove VQVAE stuff, did not work...
Diffstat (limited to 'text_recognizer/networks/quantizer/utils.py')
-rw-r--r-- | text_recognizer/networks/quantizer/utils.py | 26 |
1 files changed, 0 insertions, 26 deletions
diff --git a/text_recognizer/networks/quantizer/utils.py b/text_recognizer/networks/quantizer/utils.py deleted file mode 100644 index 0502d49..0000000 --- a/text_recognizer/networks/quantizer/utils.py +++ /dev/null @@ -1,26 +0,0 @@ -"""Helper functions for quantization.""" -from typing import Tuple - -import torch -from torch import Tensor -import torch.nn.functional as F - - -def sample_vectors(samples: Tensor, num: int) -> Tensor: - """Subsamples a set of vectors.""" - B, device = samples.shape[0], samples.device - if B >= num: - indices = torch.randperm(B, device=device)[:num] - else: - indices = torch.randint(0, B, (num,), device=device)[:num] - return samples[indices] - - -def norm(t: Tensor) -> Tensor: - """Applies L2-normalization.""" - return F.normalize(t, p=2, dim=-1) - - -def ema_inplace(moving_avg: Tensor, new: Tensor, decay: float) -> None: - """Applies exponential moving average.""" - moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) |