summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/vqvae/vqvae.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/vqvae/vqvae.py')
-rw-r--r--text_recognizer/networks/vqvae/vqvae.py12
1 files changed, 7 insertions, 5 deletions
diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py
index f31b062..2d17e0f 100644
--- a/text_recognizer/networks/vqvae/vqvae.py
+++ b/text_recognizer/networks/vqvae/vqvae.py
@@ -29,7 +29,9 @@ class VQVAE(nn.Module):
in_channels=embedding_dim, out_channels=hidden_dim, kernel_size=1
)
self.quantizer = VectorQuantizer(
- num_embeddings=num_embeddings, embedding_dim=embedding_dim, decay=decay,
+ num_embeddings=num_embeddings,
+ embedding_dim=embedding_dim,
+ decay=decay,
)
def encode(self, x: Tensor) -> Tensor:
@@ -39,8 +41,8 @@ class VQVAE(nn.Module):
def quantize(self, z_e: Tensor) -> Tuple[Tensor, Tensor]:
"""Quantizes the encoded latent vectors."""
- z_q, vq_loss = self.quantizer(z_e)
- return z_q, vq_loss
+ z_q, commitment_loss = self.quantizer(z_e)
+ return z_q, commitment_loss
def decode(self, z_q: Tensor) -> Tensor:
"""Reconstructs input from latent codes."""
@@ -51,6 +53,6 @@ class VQVAE(nn.Module):
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
"""Compresses and decompresses input."""
z_e = self.encode(x)
- z_q, vq_loss = self.quantize(z_e)
+ z_q, commitment_loss = self.quantize(z_e)
x_hat = self.decode(z_q)
- return x_hat, vq_loss
+ return x_hat, commitment_loss