summaryrefslogtreecommitdiff
path: root/text_recognizer/models/vqvae.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-05-02 22:27:42 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-05-02 22:27:42 +0200
commit737000da5b44276512beffc1bdf81057df43ab2c (patch)
treed44fa30079d8db0534c14b6d53e8524e05673620 /text_recognizer/models/vqvae.py
parent1baeae6b414f71906bd1480d3ddc393ae878bd63 (diff)
Attention layer finished
Diffstat (limited to 'text_recognizer/models/vqvae.py')
-rw-r--r--text_recognizer/models/vqvae.py12
1 files changed, 8 insertions, 4 deletions
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
index ef2213c..078235e 100644
--- a/text_recognizer/models/vqvae.py
+++ b/text_recognizer/models/vqvae.py
@@ -29,12 +29,14 @@ class LitVQVAEModel(LitBaseModel):
"""Forward pass with the transformer network."""
return self.network.predict(data)
- def _log_prediction(self, data: Tensor, reconstructions: Tensor) -> None:
+ def _log_prediction(
+ self, data: Tensor, reconstructions: Tensor, title: str
+ ) -> None:
"""Logs prediction on image with wandb."""
try:
self.logger.experiment.log(
{
- "val_pred_examples": [
+ title: [
wandb.Image(data[0]),
wandb.Image(reconstructions[0]),
]
@@ -59,7 +61,8 @@ class LitVQVAEModel(LitBaseModel):
loss = self.loss_fn(reconstructions, data)
loss += vq_loss
self.log("val_loss", loss, prog_bar=True)
- self._log_prediction(data, reconstructions)
+ title = "val_pred_examples"
+ self._log_prediction(data, reconstructions, title)
def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
"""Test step."""
@@ -67,4 +70,5 @@ class LitVQVAEModel(LitBaseModel):
reconstructions, vq_loss = self.network(data)
loss = self.loss_fn(reconstructions, data)
loss += vq_loss
- self._log_prediction(data, reconstructions)
+ title = "test_pred_examples"
+ self._log_prediction(data, reconstructions, title)