diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-09-18 17:41:32 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-09-18 17:41:32 +0200 |
commit | f132ea7e0299a4cdc57245d8339f10411c13bfbc (patch) | |
tree | 1ae83b6b415d5c2f35da50563f2403c423734ff4 /training | |
parent | 22a20270b86ca4159b4d338f8cd789e11867a562 (diff) |
Update wandb reconsturction logging
Diffstat (limited to 'training')
-rw-r--r-- | training/callbacks/wandb_callbacks.py | 16 |
1 files changed, 7 insertions, 9 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py index 11d0936..68e4135 100644 --- a/training/callbacks/wandb_callbacks.py +++ b/training/callbacks/wandb_callbacks.py @@ -6,6 +6,7 @@ from pytorch_lightning import Callback, LightningModule, Trainer from pytorch_lightning.loggers import LoggerCollection, WandbLogger from pytorch_lightning.utilities import rank_zero_only from torch.utils.data import DataLoader +from torchvision.utils import make_grid def get_wandb_logger(trainer: Trainer) -> WandbLogger: @@ -174,23 +175,20 @@ class LogReconstuctedImages(Callback): samples = next(iter(dataloader)) imgs, _ = samples - colums = ["input", "reconstruction"] imgs = imgs.to(device=pl_module.device) reconstructions = pl_module(imgs)[0] + data = [ - [wandb.Image(img), wandb.Image(rec)] + wandb.Image( + make_grid([img, rec]), caption="Left: Image, Right: Reconstruction" + ) + # wandb.Image(rec, caption="Reconstruction"), for img, rec in zip( imgs[: self.num_samples], reconstructions[: self.num_samples] ) ] - experiment.log( - { - f"Reconstructions/{experiment.name}/{stage}": wandb.Table( - data=data, columns=colums - ) - } - ) + experiment.log({f"Reconstructions/{experiment.name}/{stage}": data}) def on_sanity_check_start( self, trainer: Trainer, pl_module: LightningModule |