diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-09-11 15:44:14 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-09-11 15:44:14 +0200 |
commit | 9c829df67f0a874b2803769dc8ff3681a3c095b1 (patch) | |
tree | 974cff555d655a43f2a98830d6848adc89ead6f1 /text_recognizer/models | |
parent | 276c24bdc4817f2817b47b7a3a6bcfd9bb47b2ef (diff) |
Rename vq loss to commitment loss
Diffstat (limited to 'text_recognizer/models')
-rw-r--r-- | text_recognizer/models/vqvae.py | 34 |
1 files changed, 16 insertions, 18 deletions
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py index 56229b3..92f28ad 100644 --- a/text_recognizer/models/vqvae.py +++ b/text_recognizer/models/vqvae.py @@ -11,7 +11,7 @@ from text_recognizer.models.base import BaseLitModel class VQVAELitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" - latent_loss_weight: float = attr.ib(default=0.25) + commitment: float = attr.ib(default=0.25) def forward(self, data: Tensor) -> Tensor: """Forward pass with the transformer network.""" @@ -21,37 +21,35 @@ class VQVAELitModel(BaseLitModel): """Training step.""" data, _ = batch - reconstructions, vq_loss = self(data) + reconstructions, commitment_loss = self(data) + loss = self.loss_fn(reconstructions, data) - loss = loss + self.latent_loss_weight * vq_loss + loss = loss + self.commitment * commitment_loss - self.log("train/vq_loss", vq_loss) + self.log("train/commitment_loss", commitment_loss) self.log("train/loss", loss) - - # self.train_acc(reconstructions, data) - # self.log("train/acc", self.train_acc, on_step=False, on_epoch=True) return loss def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Validation step.""" data, _ = batch - reconstructions, vq_loss = self(data) + + reconstructions, commitment_loss = self(data) + loss = self.loss_fn(reconstructions, data) - loss = loss + self.latent_loss_weight * vq_loss + loss = loss + self.commitment * commitment_loss - self.log("val/vq_loss", vq_loss) + self.log("val/commitment_loss", commitment_loss) self.log("val/loss", loss, prog_bar=True) - # self.val_acc(reconstructions, data) - # self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True) - def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Test step.""" data, _ = batch - reconstructions, vq_loss = self(data) + + reconstructions, commitment_loss = self(data) + loss = self.loss_fn(reconstructions, data) - loss = loss + self.latent_loss_weight * vq_loss - self.log("test/vq_loss", vq_loss) + loss = loss + self.commitment * commitment_loss + + self.log("test/commitment_loss", commitment_loss) self.log("test/loss", loss) - # self.test_acc(reconstructions, data) - # self.log("test/acc", self.test_acc, on_step=False, on_epoch=True) |