From 65395f8b40c02ed2be6d438665917108741ad15c Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Thu, 30 Sep 2021 23:06:03 +0200
Subject: Rename vqloss to commitment loss in vqvae network

---
 text_recognizer/networks/vqvae/vqvae.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

(limited to 'text_recognizer/networks')

diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py
index f31b062..2d17e0f 100644
--- a/text_recognizer/networks/vqvae/vqvae.py
+++ b/text_recognizer/networks/vqvae/vqvae.py
@@ -29,7 +29,9 @@ class VQVAE(nn.Module):
             in_channels=embedding_dim, out_channels=hidden_dim, kernel_size=1
         )
         self.quantizer = VectorQuantizer(
-            num_embeddings=num_embeddings, embedding_dim=embedding_dim, decay=decay,
+            num_embeddings=num_embeddings,
+            embedding_dim=embedding_dim,
+            decay=decay,
         )
 
     def encode(self, x: Tensor) -> Tensor:
@@ -39,8 +41,8 @@ class VQVAE(nn.Module):
 
     def quantize(self, z_e: Tensor) -> Tuple[Tensor, Tensor]:
         """Quantizes the encoded latent vectors."""
-        z_q, vq_loss = self.quantizer(z_e)
-        return z_q, vq_loss
+        z_q, commitment_loss = self.quantizer(z_e)
+        return z_q, commitment_loss
 
     def decode(self, z_q: Tensor) -> Tensor:
         """Reconstructs input from latent codes."""
@@ -51,6 +53,6 @@ class VQVAE(nn.Module):
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
         """Compresses and decompresses input."""
         z_e = self.encode(x)
-        z_q, vq_loss = self.quantize(z_e)
+        z_q, commitment_loss = self.quantize(z_e)
         x_hat = self.decode(z_q)
-        return x_hat, vq_loss
+        return x_hat, commitment_loss
-- 
cgit v1.2.3-70-g09d2