diff options
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r-- | text_recognizer/networks/vqvae/quantizer.py | 18 |
1 files changed, 8 insertions, 10 deletions
diff --git a/text_recognizer/networks/vqvae/quantizer.py b/text_recognizer/networks/vqvae/quantizer.py index 6fb57e8..bba9b60 100644 --- a/text_recognizer/networks/vqvae/quantizer.py +++ b/text_recognizer/networks/vqvae/quantizer.py @@ -34,7 +34,7 @@ class VectorQuantizer(nn.Module): self.decay = decay self.embedding = EmbeddingEMA(self.num_embeddings, self.embedding_dim) - def discretization_bottleneck(self, latent: Tensor) -> Tensor: + def _discretization_bottleneck(self, latent: Tensor) -> Tensor: """Computes the code nearest to the latent representation. First we compute the posterior categorical distribution, and then map @@ -78,11 +78,11 @@ class VectorQuantizer(nn.Module): quantized_latent, "(b h w) d -> b h w d", b=b, h=h, w=w ) if self.training: - self.compute_ema(one_hot_encoding=one_hot_encoding, latent=latent) + self._compute_ema(one_hot_encoding=one_hot_encoding, latent=latent) return quantized_latent - def compute_ema(self, one_hot_encoding: Tensor, latent: Tensor) -> None: + def _compute_ema(self, one_hot_encoding: Tensor, latent: Tensor) -> None: """Computes the EMA update to the codebook.""" batch_cluster_size = one_hot_encoding.sum(axis=0) batch_embedding_avg = (latent.t() @ one_hot_encoding).t() @@ -97,7 +97,7 @@ class VectorQuantizer(nn.Module): ).unsqueeze(1) self.embedding.weight.data.copy_(new_embedding) - def vq_loss(self, latent: Tensor, quantized_latent: Tensor) -> Tensor: + def _commitment_loss(self, latent: Tensor, quantized_latent: Tensor) -> Tensor: """Vector Quantization loss. The vector quantization algorithm allows us to create a codebook. The VQ @@ -119,10 +119,8 @@ class VectorQuantizer(nn.Module): Tensor: The combinded VQ loss. """ - commitment_loss = F.mse_loss(quantized_latent.detach(), latent) - # embedding_loss = F.mse_loss(quantized_latent, latent.detach()) - # return embedding_loss + self.beta * commitment_loss - return commitment_loss + loss = F.mse_loss(quantized_latent.detach(), latent) + return loss def forward(self, latent: Tensor) -> Tensor: """Forward pass that returns the quantized vector and the vq loss.""" @@ -130,9 +128,9 @@ class VectorQuantizer(nn.Module): latent = rearrange(latent, "b d h w -> b h w d") # Maps latent to the nearest code in the codebook. - quantized_latent = self.discretization_bottleneck(latent) + quantized_latent = self._discretization_bottleneck(latent) - loss = self.vq_loss(latent, quantized_latent) + loss = self._commitment_loss(latent, quantized_latent) # Add residue to the quantized latent. quantized_latent = latent + (quantized_latent - latent).detach() |