diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-05-02 22:27:42 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-05-02 22:27:42 +0200 |
commit | 737000da5b44276512beffc1bdf81057df43ab2c (patch) | |
tree | d44fa30079d8db0534c14b6d53e8524e05673620 /text_recognizer/models/vqvae.py | |
parent | 1baeae6b414f71906bd1480d3ddc393ae878bd63 (diff) |
Attention layer finished
Diffstat (limited to 'text_recognizer/models/vqvae.py')
-rw-r--r-- | text_recognizer/models/vqvae.py | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py index ef2213c..078235e 100644 --- a/text_recognizer/models/vqvae.py +++ b/text_recognizer/models/vqvae.py @@ -29,12 +29,14 @@ class LitVQVAEModel(LitBaseModel): """Forward pass with the transformer network.""" return self.network.predict(data) - def _log_prediction(self, data: Tensor, reconstructions: Tensor) -> None: + def _log_prediction( + self, data: Tensor, reconstructions: Tensor, title: str + ) -> None: """Logs prediction on image with wandb.""" try: self.logger.experiment.log( { - "val_pred_examples": [ + title: [ wandb.Image(data[0]), wandb.Image(reconstructions[0]), ] @@ -59,7 +61,8 @@ class LitVQVAEModel(LitBaseModel): loss = self.loss_fn(reconstructions, data) loss += vq_loss self.log("val_loss", loss, prog_bar=True) - self._log_prediction(data, reconstructions) + title = "val_pred_examples" + self._log_prediction(data, reconstructions, title) def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Test step.""" @@ -67,4 +70,5 @@ class LitVQVAEModel(LitBaseModel): reconstructions, vq_loss = self.network(data) loss = self.loss_fn(reconstructions, data) loss += vq_loss - self._log_prediction(data, reconstructions) + title = "test_pred_examples" + self._log_prediction(data, reconstructions, title) |