diff options
Diffstat (limited to 'text_recognizer/networks/vqvae/vqvae.py')
-rw-r--r-- | text_recognizer/networks/vqvae/vqvae.py | 12 |
1 files changed, 7 insertions, 5 deletions
diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py index f31b062..2d17e0f 100644 --- a/text_recognizer/networks/vqvae/vqvae.py +++ b/text_recognizer/networks/vqvae/vqvae.py @@ -29,7 +29,9 @@ class VQVAE(nn.Module): in_channels=embedding_dim, out_channels=hidden_dim, kernel_size=1 ) self.quantizer = VectorQuantizer( - num_embeddings=num_embeddings, embedding_dim=embedding_dim, decay=decay, + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + decay=decay, ) def encode(self, x: Tensor) -> Tensor: @@ -39,8 +41,8 @@ class VQVAE(nn.Module): def quantize(self, z_e: Tensor) -> Tuple[Tensor, Tensor]: """Quantizes the encoded latent vectors.""" - z_q, vq_loss = self.quantizer(z_e) - return z_q, vq_loss + z_q, commitment_loss = self.quantizer(z_e) + return z_q, commitment_loss def decode(self, z_q: Tensor) -> Tensor: """Reconstructs input from latent codes.""" @@ -51,6 +53,6 @@ class VQVAE(nn.Module): def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: """Compresses and decompresses input.""" z_e = self.encode(x) - z_q, vq_loss = self.quantize(z_e) + z_q, commitment_loss = self.quantize(z_e) x_hat = self.decode(z_q) - return x_hat, vq_loss + return x_hat, commitment_loss |