summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/quantizer/utils.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-17 22:42:58 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-17 22:42:58 +0100
commit2417288c9fe96264da708ce8d13ac7bc2faf83e3 (patch)
treea74fb7a8502b1642b71240608706efda14dee3f9 /text_recognizer/networks/quantizer/utils.py
parent2cb2c5b38f0711267fecfe9c5e10940f4b4f79fc (diff)
Add new quantizer
Diffstat (limited to 'text_recognizer/networks/quantizer/utils.py')
-rw-r--r--text_recognizer/networks/quantizer/utils.py26
1 files changed, 26 insertions, 0 deletions
diff --git a/text_recognizer/networks/quantizer/utils.py b/text_recognizer/networks/quantizer/utils.py
new file mode 100644
index 0000000..0502d49
--- /dev/null
+++ b/text_recognizer/networks/quantizer/utils.py
@@ -0,0 +1,26 @@
+"""Helper functions for quantization."""
+from typing import Tuple
+
+import torch
+from torch import Tensor
+import torch.nn.functional as F
+
+
+def sample_vectors(samples: Tensor, num: int) -> Tensor:
+ """Subsamples a set of vectors."""
+ B, device = samples.shape[0], samples.device
+ if B >= num:
+ indices = torch.randperm(B, device=device)[:num]
+ else:
+ indices = torch.randint(0, B, (num,), device=device)[:num]
+ return samples[indices]
+
+
+def norm(t: Tensor) -> Tensor:
+ """Applies L2-normalization."""
+ return F.normalize(t, p=2, dim=-1)
+
+
+def ema_inplace(moving_avg: Tensor, new: Tensor, decay: float) -> None:
+ """Applies exponential moving average."""
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))