From f132ea7e0299a4cdc57245d8339f10411c13bfbc Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sat, 18 Sep 2021 17:41:32 +0200 Subject: Update wandb reconsturction logging --- training/callbacks/wandb_callbacks.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py index 11d0936..68e4135 100644 --- a/training/callbacks/wandb_callbacks.py +++ b/training/callbacks/wandb_callbacks.py @@ -6,6 +6,7 @@ from pytorch_lightning import Callback, LightningModule, Trainer from pytorch_lightning.loggers import LoggerCollection, WandbLogger from pytorch_lightning.utilities import rank_zero_only from torch.utils.data import DataLoader +from torchvision.utils import make_grid def get_wandb_logger(trainer: Trainer) -> WandbLogger: @@ -174,23 +175,20 @@ class LogReconstuctedImages(Callback): samples = next(iter(dataloader)) imgs, _ = samples - colums = ["input", "reconstruction"] imgs = imgs.to(device=pl_module.device) reconstructions = pl_module(imgs)[0] + data = [ - [wandb.Image(img), wandb.Image(rec)] + wandb.Image( + make_grid([img, rec]), caption="Left: Image, Right: Reconstruction" + ) + # wandb.Image(rec, caption="Reconstruction"), for img, rec in zip( imgs[: self.num_samples], reconstructions[: self.num_samples] ) ] - experiment.log( - { - f"Reconstructions/{experiment.name}/{stage}": wandb.Table( - data=data, columns=colums - ) - } - ) + experiment.log({f"Reconstructions/{experiment.name}/{stage}": data}) def on_sanity_check_start( self, trainer: Trainer, pl_module: LightningModule -- cgit v1.2.3-70-g09d2