summaryrefslogtreecommitdiff
path: root/training/callbacks
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-04 05:03:51 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-04 05:03:51 +0200
commitd3afa310f77f47553586eeee58e3d3345a754e2c (patch)
tree08b7de1daf2550852d0a1e4d4d75202f14bb03d4 /training/callbacks
parent65d5f6c694e73792e40ed693a1381a792da8d277 (diff)
New VQVAE
Diffstat (limited to 'training/callbacks')
-rw-r--r--training/callbacks/wandb_callbacks.py69
1 files changed, 46 insertions, 23 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py
index 906531f..c750e4b 100644
--- a/training/callbacks/wandb_callbacks.py
+++ b/training/callbacks/wandb_callbacks.py
@@ -5,6 +5,7 @@ import wandb
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.loggers import LoggerCollection, WandbLogger
from pytorch_lightning.utilities import rank_zero_only
+from torch.utils.data import DataLoader
def get_wandb_logger(trainer: Trainer) -> WandbLogger:
@@ -86,7 +87,11 @@ class LogTextPredictions(Callback):
self.ready = False
def _log_predictions(
- self, stage: str, trainer: Trainer, pl_module: LightningModule
+ self,
+ stage: str,
+ trainer: Trainer,
+ pl_module: LightningModule,
+ dataloader: DataLoader,
) -> None:
"""Logs the predicted text contained in the images."""
if not self.ready:
@@ -96,22 +101,20 @@ class LogTextPredictions(Callback):
experiment = logger.experiment
# Get a validation batch from the validation dataloader.
- samples = next(iter(trainer.datamodule.val_dataloader()))
+ samples = next(iter(dataloader))
imgs, labels = samples
imgs = imgs.to(device=pl_module.device)
logits = pl_module(imgs)
mapping = pl_module.mapping
- columns = ["id", "image", "prediction", "truth"]
+ columns = ["image", "prediction", "truth"]
data = [
- [id, wandb.Image(img), mapping.get_text(pred), mapping.get_text(label)]
- for id, (img, pred, label) in enumerate(
- zip(
- imgs[: self.num_samples],
- logits[: self.num_samples],
- labels[: self.num_samples],
- )
+ [wandb.Image(img), mapping.get_text(pred), mapping.get_text(label)]
+ for img, pred, label in zip(
+ imgs[: self.num_samples],
+ logits[: self.num_samples],
+ labels[: self.num_samples],
)
]
@@ -133,11 +136,17 @@ class LogTextPredictions(Callback):
self, trainer: Trainer, pl_module: LightningModule
) -> None:
"""Logs predictions on validation epoch end."""
- self._log_predictions(stage="val", trainer=trainer, pl_module=pl_module)
+ dataloader = trainer.datamodule.val_dataloader()
+ self._log_predictions(
+ stage="val", trainer=trainer, pl_module=pl_module, dataloader=dataloader
+ )
def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Logs predictions on train epoch end."""
- self._log_predictions(stage="test", trainer=trainer, pl_module=pl_module)
+ dataloader = trainer.datamodule.test_dataloader()
+ self._log_predictions(
+ stage="test", trainer=trainer, pl_module=pl_module, dataloader=dataloader
+ )
class LogReconstuctedImages(Callback):
@@ -148,7 +157,11 @@ class LogReconstuctedImages(Callback):
self.ready = False
def _log_reconstruction(
- self, stage: str, trainer: Trainer, pl_module: LightningModule
+ self,
+ stage: str,
+ trainer: Trainer,
+ pl_module: LightningModule,
+ dataloader: DataLoader,
) -> None:
"""Logs the reconstructions."""
if not self.ready:
@@ -158,20 +171,24 @@ class LogReconstuctedImages(Callback):
experiment = logger.experiment
# Get a validation batch from the validation dataloader.
- samples = next(iter(trainer.datamodule.val_dataloader()))
+ samples = next(iter(dataloader))
imgs, _ = samples
+ colums = ["input", "reconstruction"]
imgs = imgs.to(device=pl_module.device)
- reconstructions = pl_module(imgs)
+ reconstructions = pl_module(imgs)[0]
+ data = [
+ [wandb.Image(img), wandb.Image(rec)]
+ for img, rec in zip(
+ imgs[: self.num_samples], reconstructions[: self.num_samples]
+ )
+ ]
experiment.log(
{
- f"Reconstructions/{experiment.name}/{stage}": [
- [wandb.Image(img), wandb.Image(rec),]
- for img, rec in zip(
- imgs[: self.num_samples], reconstructions[: self.num_samples],
- )
- ]
+ f"Reconstructions/{experiment.name}/{stage}": wandb.Table(
+ data=data, columns=colums
+ )
}
)
@@ -189,8 +206,14 @@ class LogReconstuctedImages(Callback):
self, trainer: Trainer, pl_module: LightningModule
) -> None:
"""Logs predictions on validation epoch end."""
- self._log_reconstruction(stage="val", trainer=trainer, pl_module=pl_module)
+ dataloader = trainer.datamodule.val_dataloader()
+ self._log_reconstruction(
+ stage="val", trainer=trainer, pl_module=pl_module, dataloader=dataloader
+ )
def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Logs predictions on train epoch end."""
- self._log_reconstruction(stage="test", trainer=trainer, pl_module=pl_module)
+ dataloader = trainer.datamodule.test_dataloader()
+ self._log_reconstruction(
+ stage="test", trainer=trainer, pl_module=pl_module, dataloader=dataloader
+ )