From 9c829df67f0a874b2803769dc8ff3681a3c095b1 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Sat, 11 Sep 2021 15:44:14 +0200
Subject: Rename vq loss to commitment loss

---
 text_recognizer/networks/vqvae/quantizer.py | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

diff --git a/text_recognizer/networks/vqvae/quantizer.py b/text_recognizer/networks/vqvae/quantizer.py
index 6fb57e8..bba9b60 100644
--- a/text_recognizer/networks/vqvae/quantizer.py
+++ b/text_recognizer/networks/vqvae/quantizer.py
@@ -34,7 +34,7 @@ class VectorQuantizer(nn.Module):
         self.decay = decay
         self.embedding = EmbeddingEMA(self.num_embeddings, self.embedding_dim)
 
-    def discretization_bottleneck(self, latent: Tensor) -> Tensor:
+    def _discretization_bottleneck(self, latent: Tensor) -> Tensor:
         """Computes the code nearest to the latent representation.
 
         First we compute the posterior categorical distribution, and then map
@@ -78,11 +78,11 @@ class VectorQuantizer(nn.Module):
             quantized_latent, "(b h w) d -> b h w d", b=b, h=h, w=w
         )
         if self.training:
-            self.compute_ema(one_hot_encoding=one_hot_encoding, latent=latent)
+            self._compute_ema(one_hot_encoding=one_hot_encoding, latent=latent)
 
         return quantized_latent
 
-    def compute_ema(self, one_hot_encoding: Tensor, latent: Tensor) -> None:
+    def _compute_ema(self, one_hot_encoding: Tensor, latent: Tensor) -> None:
         """Computes the EMA update to the codebook."""
         batch_cluster_size = one_hot_encoding.sum(axis=0)
         batch_embedding_avg = (latent.t() @ one_hot_encoding).t()
@@ -97,7 +97,7 @@ class VectorQuantizer(nn.Module):
         ).unsqueeze(1)
         self.embedding.weight.data.copy_(new_embedding)
 
-    def vq_loss(self, latent: Tensor, quantized_latent: Tensor) -> Tensor:
+    def _commitment_loss(self, latent: Tensor, quantized_latent: Tensor) -> Tensor:
         """Vector Quantization loss.
 
         The vector quantization algorithm allows us to create a codebook. The VQ
@@ -119,10 +119,8 @@ class VectorQuantizer(nn.Module):
             Tensor: The combined VQ loss.
 
         """
-        commitment_loss = F.mse_loss(quantized_latent.detach(), latent)
-        # embedding_loss = F.mse_loss(quantized_latent, latent.detach())
-        # return embedding_loss + self.beta * commitment_loss
-        return commitment_loss
+        loss = F.mse_loss(quantized_latent.detach(), latent)
+        return loss
 
     def forward(self, latent: Tensor) -> Tensor:
         """Forward pass that returns the quantized vector and the vq loss."""
@@ -130,9 +128,9 @@ class VectorQuantizer(nn.Module):
         latent = rearrange(latent, "b d h w -> b h w d")
 
         # Maps latent to the nearest code in the codebook.
-        quantized_latent = self.discretization_bottleneck(latent)
+        quantized_latent = self._discretization_bottleneck(latent)
 
-        loss = self.vq_loss(latent, quantized_latent)
+        loss = self._commitment_loss(latent, quantized_latent)
 
         # Add residue to the quantized latent.
         quantized_latent = latent + (quantized_latent - latent).detach()
-- 
cgit v1.2.3-70-g09d2
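
A minimal sketch of what the renamed pieces compute, kept separate from the patch itself: the commitment loss pulls the encoder output towards the selected (detached) codebook entries, and the straight-through residual copies decoder gradients past the non-differentiable lookup. The toy codebook, tensor shapes, and helper names below are illustrative assumptions; the actual module's EmbeddingEMA codebook and einops-based reshaping are omitted.

    import torch
    import torch.nn.functional as F


    def commitment_loss(latent: torch.Tensor, quantized_latent: torch.Tensor) -> torch.Tensor:
        # Pull the encoder output towards the frozen (detached) codebook entries.
        # The codebook itself is updated by EMA, so no embedding loss term is needed.
        return F.mse_loss(quantized_latent.detach(), latent)


    def straight_through(latent: torch.Tensor, quantized_latent: torch.Tensor) -> torch.Tensor:
        # Copy decoder gradients straight through the non-differentiable quantization step.
        return latent + (quantized_latent - latent).detach()


    if __name__ == "__main__":
        # Toy example: a random codebook and a small batch of encoder outputs.
        torch.manual_seed(0)
        codebook = torch.randn(16, 8)                    # 16 codes of dimension 8
        latent = torch.randn(4, 8, requires_grad=True)   # (batch, embedding_dim)

        # Nearest-neighbour lookup by Euclidean distance.
        distances = torch.cdist(latent, codebook)
        indices = distances.argmin(dim=-1)
        quantized = codebook[indices]

        loss = commitment_loss(latent, quantized)
        quantized_st = straight_through(latent, quantized)

        loss.backward()
        print(loss.item(), quantized_st.shape)

Because the codebook is updated by _compute_ema rather than by gradient descent, only the commitment term remains in the loss, which is why the commented-out embedding loss (embedding_loss + self.beta * commitment_loss) was dropped along with the rename.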