Diffstat (limited to 'text_recognizer/networks/vqvae/quantizer.py')
-rw-r--r--  text_recognizer/networks/vqvae/quantizer.py | 15
1 file changed, 9 insertions, 6 deletions
diff --git a/text_recognizer/networks/vqvae/quantizer.py b/text_recognizer/networks/vqvae/quantizer.py
index a4f11f0..6fb57e8 100644
--- a/text_recognizer/networks/vqvae/quantizer.py
+++ b/text_recognizer/networks/vqvae/quantizer.py
@@ -11,13 +11,15 @@ import torch.nn.functional as F
 
 
 class EmbeddingEMA(nn.Module):
+    """Embedding for Exponential Moving Average (EMA)."""
+
     def __init__(self, num_embeddings: int, embedding_dim: int) -> None:
         super().__init__()
         weight = torch.zeros(num_embeddings, embedding_dim)
         nn.init.kaiming_uniform_(weight, nonlinearity="linear")
         self.register_buffer("weight", weight)
-        self.register_buffer("_cluster_size", torch.zeros(num_embeddings))
-        self.register_buffer("_weight_avg", weight)
+        self.register_buffer("cluster_size", torch.zeros(num_embeddings))
+        self.register_buffer("weight_avg", weight.clone())
 
 
 class VectorQuantizer(nn.Module):
@@ -81,16 +83,17 @@ class VectorQuantizer(nn.Module):
         return quantized_latent
 
     def compute_ema(self, one_hot_encoding: Tensor, latent: Tensor) -> None:
+        """Computes the EMA update to the codebook."""
         batch_cluster_size = one_hot_encoding.sum(axis=0)
         batch_embedding_avg = (latent.t() @ one_hot_encoding).t()
-        self.embedding._cluster_size.data.mul_(self.decay).add_(
+        self.embedding.cluster_size.data.mul_(self.decay).add_(
            batch_cluster_size, alpha=1 - self.decay
         )
-        self.embedding._weight_avg.data.mul_(self.decay).add_(
+        self.embedding.weight_avg.data.mul_(self.decay).add_(
            batch_embedding_avg, alpha=1 - self.decay
         )
-        new_embedding = self.embedding._weight_avg / (
-            self.embedding._cluster_size + 1.0e-5
+        new_embedding = self.embedding.weight_avg / (
+            self.embedding.cluster_size + 1.0e-5
         ).unsqueeze(1)
         self.embedding.weight.data.copy_(new_embedding)
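The patch does two things: it renames the private `_cluster_size` and `_weight_avg` buffers to public names, and it registers `weight.clone()` instead of `weight` for the running average. The clone likely matters for correctness: `register_buffer` stores the tensor object itself, so the original code registered the same storage under both `weight` and `_weight_avg`, and every in-place EMA update to one buffer would silently mutate the other. Below is a minimal, self-contained sketch of the EMA codebook update that `compute_ema` performs; the function name `ema_codebook_update` and the `decay=0.99` default are illustrative assumptions, while `eps=1.0e-5` mirrors the constant in the diff.

    import torch


    def ema_codebook_update(
        cluster_size: torch.Tensor,      # (num_embeddings,) running usage counts
        weight_avg: torch.Tensor,        # (num_embeddings, embedding_dim) running sums
        one_hot_encoding: torch.Tensor,  # (num_latents, num_embeddings) code assignments
        latent: torch.Tensor,            # (num_latents, embedding_dim) encoder outputs
        decay: float = 0.99,             # illustrative default, not from the repo
        eps: float = 1.0e-5,             # matches the 1.0e-5 in the diff
    ) -> torch.Tensor:
        """Return new codebook weights after one EMA step, as in compute_ema."""
        # Per-code usage counts and summed latents for this batch.
        batch_cluster_size = one_hot_encoding.sum(dim=0)
        batch_embedding_avg = (latent.t() @ one_hot_encoding).t()

        # Exponential moving averages, updated in place as in the diff.
        cluster_size.mul_(decay).add_(batch_cluster_size, alpha=1 - decay)
        weight_avg.mul_(decay).add_(batch_embedding_avg, alpha=1 - decay)

        # Normalize running sums by running counts; eps guards unused codes.
        return weight_avg / (cluster_size + eps).unsqueeze(1)


    # Hypothetical usage with a small random batch:
    num_embeddings, embedding_dim, num_latents = 8, 4, 32
    cluster_size = torch.zeros(num_embeddings)
    weight_avg = torch.randn(num_embeddings, embedding_dim)
    assignments = torch.randint(0, num_embeddings, (num_latents,))
    encodings = torch.eye(num_embeddings)[assignments]
    latents = torch.randn(num_latents, embedding_dim)
    new_weight = ema_codebook_update(cluster_size, weight_avg, encodings, latents)

Note that the sketch updates `cluster_size` and `weight_avg` in place and returns the normalized result, which the module then copies into `self.embedding.weight` via `copy_`, so the gradient-free codebook tracks the encoder distribution without an optimizer step.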