author    Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-08-06 02:42:45 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-08-06 02:42:45 +0200
commit    3ab82ad36bce6fa698a13a029a0694b75a5947b7 (patch)
tree      136f71a62d60e3ccf01e1f95d64bb4d9f9c9befe /text_recognizer/networks/vqvae/quantizer.py
parent    1bccf71cf4eec335001b50a8fbc0c991d0e6d13a (diff)
Split VQVAE into encoder/decoder; fix bug in wandb artifact code uploading
Diffstat (limited to 'text_recognizer/networks/vqvae/quantizer.py')
-rw-r--r--    text_recognizer/networks/vqvae/quantizer.py    15
1 file changed, 9 insertions(+), 6 deletions(-)
diff --git a/text_recognizer/networks/vqvae/quantizer.py b/text_recognizer/networks/vqvae/quantizer.py
index a4f11f0..6fb57e8 100644
--- a/text_recognizer/networks/vqvae/quantizer.py
+++ b/text_recognizer/networks/vqvae/quantizer.py
@@ -11,13 +11,15 @@ import torch.nn.functional as F
 
 
 class EmbeddingEMA(nn.Module):
+    """Embedding for Exponential Moving Average (EMA)."""
+
     def __init__(self, num_embeddings: int, embedding_dim: int) -> None:
         super().__init__()
         weight = torch.zeros(num_embeddings, embedding_dim)
         nn.init.kaiming_uniform_(weight, nonlinearity="linear")
         self.register_buffer("weight", weight)
-        self.register_buffer("_cluster_size", torch.zeros(num_embeddings))
-        self.register_buffer("_weight_avg", weight)
+        self.register_buffer("cluster_size", torch.zeros(num_embeddings))
+        self.register_buffer("weight_avg", weight.clone())
 
 
 class VectorQuantizer(nn.Module):
@@ -81,16 +83,17 @@ class VectorQuantizer(nn.Module):
         return quantized_latent
 
     def compute_ema(self, one_hot_encoding: Tensor, latent: Tensor) -> None:
+        """Computes the EMA update to the codebook."""
         batch_cluster_size = one_hot_encoding.sum(axis=0)
         batch_embedding_avg = (latent.t() @ one_hot_encoding).t()
-        self.embedding._cluster_size.data.mul_(self.decay).add_(
+        self.embedding.cluster_size.data.mul_(self.decay).add_(
             batch_cluster_size, alpha=1 - self.decay
         )
-        self.embedding._weight_avg.data.mul_(self.decay).add_(
+        self.embedding.weight_avg.data.mul_(self.decay).add_(
             batch_embedding_avg, alpha=1 - self.decay
         )
-        new_embedding = self.embedding._weight_avg / (
-            self.embedding._cluster_size + 1.0e-5
+        new_embedding = self.embedding.weight_avg / (
+            self.embedding.cluster_size + 1.0e-5
         ).unsqueeze(1)
         self.embedding.weight.data.copy_(new_embedding)
 
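The substantive fix in the first hunk is the weight.clone(): before the patch, register_buffer("_weight_avg", weight) registered the very same tensor object that already backed the "weight" buffer, so the in-place EMA updates to the running average silently mutated the codebook weights as well. A minimal sketch demonstrating the aliasing (hypothetical standalone code, not from the repository):

    import torch
    import torch.nn as nn

    buggy = nn.Module()
    w = torch.zeros(2, 2)
    buggy.register_buffer("weight", w)
    buggy.register_buffer("weight_avg", w)          # pre-patch: same tensor, aliased
    buggy.weight_avg.add_(1.0)                      # in-place EMA-style update
    print(buggy.weight.sum())                       # tensor(4.) -> weight mutated too

    fixed = nn.Module()
    w = torch.zeros(2, 2)
    fixed.register_buffer("weight", w)
    fixed.register_buffer("weight_avg", w.clone())  # post-patch: independent storage
    fixed.weight_avg.add_(1.0)
    print(fixed.weight.sum())                       # tensor(0.) -> weight untouched

The buffer rename from _cluster_size/_weight_avg to cluster_size/weight_avg is cosmetic by comparison; the clone is what decouples the running average from the live codebook.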
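For context, compute_ema implements the exponential-moving-average codebook update used in EMA variants of VQ-VAE: per-code assignment counts and per-code latent sums are kept as decayed running averages, and each code vector is re-estimated as their ratio. Below is a minimal, self-contained sketch of the patched logic; the free-standing compute_ema function and the decay=0.99 default are assumptions made for illustration, since in the repository this is a method on VectorQuantizer that reads self.decay and self.embedding.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch import Tensor


    class EmbeddingEMA(nn.Module):
        """Embedding for Exponential Moving Average (EMA)."""

        def __init__(self, num_embeddings: int, embedding_dim: int) -> None:
            super().__init__()
            weight = torch.zeros(num_embeddings, embedding_dim)
            nn.init.kaiming_uniform_(weight, nonlinearity="linear")
            self.register_buffer("weight", weight)
            self.register_buffer("cluster_size", torch.zeros(num_embeddings))
            self.register_buffer("weight_avg", weight.clone())


    def compute_ema(
        embedding: EmbeddingEMA,
        one_hot_encoding: Tensor,
        latent: Tensor,
        decay: float = 0.99,
    ) -> None:
        """EMA update of the codebook from one batch of flattened latents.

        one_hot_encoding: (N, K) code assignments, latent: (N, D) encoder outputs.
        """
        batch_cluster_size = one_hot_encoding.sum(dim=0)           # (K,) counts
        batch_embedding_avg = (latent.t() @ one_hot_encoding).t()  # (K, D) sums
        # N_k <- decay * N_k + (1 - decay) * n_k
        embedding.cluster_size.data.mul_(decay).add_(
            batch_cluster_size, alpha=1 - decay
        )
        # m_k <- decay * m_k + (1 - decay) * (sum of latents assigned to code k)
        embedding.weight_avg.data.mul_(decay).add_(
            batch_embedding_avg, alpha=1 - decay
        )
        # e_k <- m_k / (N_k + eps)
        new_embedding = embedding.weight_avg / (
            embedding.cluster_size + 1.0e-5
        ).unsqueeze(1)
        embedding.weight.data.copy_(new_embedding)


    # Usage: assign 8 random latents to their nearest codes, then update.
    emb = EmbeddingEMA(num_embeddings=16, embedding_dim=4)
    z = torch.randn(8, 4)
    codes = torch.cdist(z, emb.weight).argmin(dim=1)
    one_hot = F.one_hot(codes, num_classes=16).float()
    compute_ema(emb, one_hot, z)

The 1.0e-5 in the denominator keeps codes that received no assignments in recent batches from dividing by zero when their running count decays toward nothing.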