From 65395f8b40c02ed2be6d438665917108741ad15c Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Thu, 30 Sep 2021 23:06:03 +0200 Subject: Rename vqloss to commitment loss in vqvae network --- text_recognizer/networks/vqvae/vqvae.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) (limited to 'text_recognizer/networks/vqvae') diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py index f31b062..2d17e0f 100644 --- a/text_recognizer/networks/vqvae/vqvae.py +++ b/text_recognizer/networks/vqvae/vqvae.py @@ -29,7 +29,9 @@ class VQVAE(nn.Module): in_channels=embedding_dim, out_channels=hidden_dim, kernel_size=1 ) self.quantizer = VectorQuantizer( - num_embeddings=num_embeddings, embedding_dim=embedding_dim, decay=decay, + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + decay=decay, ) def encode(self, x: Tensor) -> Tensor: @@ -39,8 +41,8 @@ class VQVAE(nn.Module): def quantize(self, z_e: Tensor) -> Tuple[Tensor, Tensor]: """Quantizes the encoded latent vectors.""" - z_q, vq_loss = self.quantizer(z_e) - return z_q, vq_loss + z_q, commitment_loss = self.quantizer(z_e) + return z_q, commitment_loss def decode(self, z_q: Tensor) -> Tensor: """Reconstructs input from latent codes.""" @@ -51,6 +53,6 @@ class VQVAE(nn.Module): def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: """Compresses and decompresses input.""" z_e = self.encode(x) - z_q, vq_loss = self.quantize(z_e) + z_q, commitment_loss = self.quantize(z_e) x_hat = self.decode(z_q) - return x_hat, vq_loss + return x_hat, commitment_loss -- cgit v1.2.3-70-g09d2