From 737000da5b44276512beffc1bdf81057df43ab2c Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 2 May 2021 22:27:42 +0200 Subject: Attention layer finished --- text_recognizer/models/vqvae.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) (limited to 'text_recognizer/models') diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py index ef2213c..078235e 100644 --- a/text_recognizer/models/vqvae.py +++ b/text_recognizer/models/vqvae.py @@ -29,12 +29,14 @@ class LitVQVAEModel(LitBaseModel): """Forward pass with the transformer network.""" return self.network.predict(data) - def _log_prediction(self, data: Tensor, reconstructions: Tensor) -> None: + def _log_prediction( + self, data: Tensor, reconstructions: Tensor, title: str + ) -> None: """Logs prediction on image with wandb.""" try: self.logger.experiment.log( { - "val_pred_examples": [ + title: [ wandb.Image(data[0]), wandb.Image(reconstructions[0]), ] @@ -59,7 +61,8 @@ class LitVQVAEModel(LitBaseModel): loss = self.loss_fn(reconstructions, data) loss += vq_loss self.log("val_loss", loss, prog_bar=True) - self._log_prediction(data, reconstructions) + title = "val_pred_examples" + self._log_prediction(data, reconstructions, title) def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Test step.""" @@ -67,4 +70,5 @@ class LitVQVAEModel(LitBaseModel): reconstructions, vq_loss = self.network(data) loss = self.loss_fn(reconstructions, data) loss += vq_loss - self._log_prediction(data, reconstructions) + title = "test_pred_examples" + self._log_prediction(data, reconstructions, title) -- cgit v1.2.3-70-g09d2