diff options
Diffstat (limited to 'text_recognizer/models/vqvae.py')
-rw-r--r-- | text_recognizer/models/vqvae.py | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py index ef2213c..078235e 100644 --- a/text_recognizer/models/vqvae.py +++ b/text_recognizer/models/vqvae.py @@ -29,12 +29,14 @@ class LitVQVAEModel(LitBaseModel): """Forward pass with the transformer network.""" return self.network.predict(data) - def _log_prediction(self, data: Tensor, reconstructions: Tensor) -> None: + def _log_prediction( + self, data: Tensor, reconstructions: Tensor, title: str + ) -> None: """Logs prediction on image with wandb.""" try: self.logger.experiment.log( { - "val_pred_examples": [ + title: [ wandb.Image(data[0]), wandb.Image(reconstructions[0]), ] @@ -59,7 +61,8 @@ class LitVQVAEModel(LitBaseModel): loss = self.loss_fn(reconstructions, data) loss += vq_loss self.log("val_loss", loss, prog_bar=True) - self._log_prediction(data, reconstructions) + title = "val_pred_examples" + self._log_prediction(data, reconstructions, title) def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Test step.""" @@ -67,4 +70,5 @@ class LitVQVAEModel(LitBaseModel): reconstructions, vq_loss = self.network(data) loss = self.loss_fn(reconstructions, data) loss += vq_loss - self._log_prediction(data, reconstructions) + title = "test_pred_examples" + self._log_prediction(data, reconstructions, title) |