summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/quantizer/utils.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-21 21:34:53 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-21 21:34:53 +0100
commitb44de0e11281c723ec426f8bec8ca0897ecfe3ff (patch)
tree998841a3a681d3dedfbe8470c1b8544b4dcbe7a2 /text_recognizer/networks/quantizer/utils.py
parent3b2fb0fd977a6aff4dcf88e1a0f99faac51e05b1 (diff)
Remove VQVAE stuff, did not work...
Diffstat (limited to 'text_recognizer/networks/quantizer/utils.py')
-rw-r--r--text_recognizer/networks/quantizer/utils.py26
1 files changed, 0 insertions, 26 deletions
diff --git a/text_recognizer/networks/quantizer/utils.py b/text_recognizer/networks/quantizer/utils.py
deleted file mode 100644
index 0502d49..0000000
--- a/text_recognizer/networks/quantizer/utils.py
+++ /dev/null
@@ -1,26 +0,0 @@
-"""Helper functions for quantization."""
-from typing import Tuple
-
-import torch
-from torch import Tensor
-import torch.nn.functional as F
-
-
-def sample_vectors(samples: Tensor, num: int) -> Tensor:
- """Subsamples a set of vectors."""
- B, device = samples.shape[0], samples.device
- if B >= num:
- indices = torch.randperm(B, device=device)[:num]
- else:
- indices = torch.randint(0, B, (num,), device=device)[:num]
- return samples[indices]
-
-
-def norm(t: Tensor) -> Tensor:
- """Applies L2-normalization."""
- return F.normalize(t, p=2, dim=-1)
-
-
-def ema_inplace(moving_avg: Tensor, new: Tensor, decay: float) -> None:
- """Applies exponential moving average."""
- moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))