From 9c829df67f0a874b2803769dc8ff3681a3c095b1 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sat, 11 Sep 2021 15:44:14 +0200 Subject: Rename vq loss to commitment loss --- text_recognizer/models/vqvae.py | 34 ++++++++++++++++------------------ 1 file changed, 16 insertions(+), 18 deletions(-) (limited to 'text_recognizer/models/vqvae.py') diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py index 56229b3..92f28ad 100644 --- a/text_recognizer/models/vqvae.py +++ b/text_recognizer/models/vqvae.py @@ -11,7 +11,7 @@ from text_recognizer.models.base import BaseLitModel class VQVAELitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" - latent_loss_weight: float = attr.ib(default=0.25) + commitment: float = attr.ib(default=0.25) def forward(self, data: Tensor) -> Tensor: """Forward pass with the transformer network.""" @@ -21,37 +21,35 @@ class VQVAELitModel(BaseLitModel): """Training step.""" data, _ = batch - reconstructions, vq_loss = self(data) + reconstructions, commitment_loss = self(data) + loss = self.loss_fn(reconstructions, data) - loss = loss + self.latent_loss_weight * vq_loss + loss = loss + self.commitment * commitment_loss - self.log("train/vq_loss", vq_loss) + self.log("train/commitment_loss", commitment_loss) self.log("train/loss", loss) - - # self.train_acc(reconstructions, data) - # self.log("train/acc", self.train_acc, on_step=False, on_epoch=True) return loss def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Validation step.""" data, _ = batch - reconstructions, vq_loss = self(data) + + reconstructions, commitment_loss = self(data) + loss = self.loss_fn(reconstructions, data) - loss = loss + self.latent_loss_weight * vq_loss + loss = loss + self.commitment * commitment_loss - self.log("val/vq_loss", vq_loss) + self.log("val/commitment_loss", commitment_loss) self.log("val/loss", loss, prog_bar=True) - # self.val_acc(reconstructions, data) - # self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True) - def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Test step.""" data, _ = batch - reconstructions, vq_loss = self(data) + + reconstructions, commitment_loss = self(data) + loss = self.loss_fn(reconstructions, data) - loss = loss + self.latent_loss_weight * vq_loss - self.log("test/vq_loss", vq_loss) + loss = loss + self.commitment * commitment_loss + + self.log("test/commitment_loss", commitment_loss) self.log("test/loss", loss) - # self.test_acc(reconstructions, data) - # self.log("test/acc", self.test_acc, on_step=False, on_epoch=True) -- cgit v1.2.3-70-g09d2