diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-06 02:42:45 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-06 02:42:45 +0200 |
commit | 3ab82ad36bce6fa698a13a029a0694b75a5947b7 (patch) | |
tree | 136f71a62d60e3ccf01e1f95d64bb4d9f9c9befe /text_recognizer/networks/vqvae/quantizer.py | |
parent | 1bccf71cf4eec335001b50a8fbc0c991d0e6d13a (diff) |
Split VQVAE into encoder/decoder; fix bug in wandb artifact code uploading
Diffstat (limited to 'text_recognizer/networks/vqvae/quantizer.py')
-rw-r--r-- | text_recognizer/networks/vqvae/quantizer.py | 15 |
1 file changed, 9 insertions, 6 deletions
diff --git a/text_recognizer/networks/vqvae/quantizer.py b/text_recognizer/networks/vqvae/quantizer.py index a4f11f0..6fb57e8 100644 --- a/text_recognizer/networks/vqvae/quantizer.py +++ b/text_recognizer/networks/vqvae/quantizer.py @@ -11,13 +11,15 @@ import torch.nn.functional as F class EmbeddingEMA(nn.Module): + """Embedding for Exponential Moving Average (EMA).""" + def __init__(self, num_embeddings: int, embedding_dim: int) -> None: super().__init__() weight = torch.zeros(num_embeddings, embedding_dim) nn.init.kaiming_uniform_(weight, nonlinearity="linear") self.register_buffer("weight", weight) - self.register_buffer("_cluster_size", torch.zeros(num_embeddings)) - self.register_buffer("_weight_avg", weight) + self.register_buffer("cluster_size", torch.zeros(num_embeddings)) + self.register_buffer("weight_avg", weight.clone()) class VectorQuantizer(nn.Module): @@ -81,16 +83,17 @@ class VectorQuantizer(nn.Module): return quantized_latent def compute_ema(self, one_hot_encoding: Tensor, latent: Tensor) -> None: + """Computes the EMA update to the codebook.""" batch_cluster_size = one_hot_encoding.sum(axis=0) batch_embedding_avg = (latent.t() @ one_hot_encoding).t() - self.embedding._cluster_size.data.mul_(self.decay).add_( + self.embedding.cluster_size.data.mul_(self.decay).add_( batch_cluster_size, alpha=1 - self.decay ) - self.embedding._weight_avg.data.mul_(self.decay).add_( + self.embedding.weight_avg.data.mul_(self.decay).add_( batch_embedding_avg, alpha=1 - self.decay ) - new_embedding = self.embedding._weight_avg / ( - self.embedding._cluster_size + 1.0e-5 + new_embedding = self.embedding.weight_avg / ( + self.embedding.cluster_size + 1.0e-5 ).unsqueeze(1) self.embedding.weight.data.copy_(new_embedding) |