diff options
Diffstat (limited to 'training/callbacks/wandb_callbacks.py')
-rw-r--r-- | training/callbacks/wandb_callbacks.py | 16 |
1 files changed, 7 insertions, 9 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py index 11d0936..68e4135 100644 --- a/training/callbacks/wandb_callbacks.py +++ b/training/callbacks/wandb_callbacks.py @@ -6,6 +6,7 @@ from pytorch_lightning import Callback, LightningModule, Trainer from pytorch_lightning.loggers import LoggerCollection, WandbLogger from pytorch_lightning.utilities import rank_zero_only from torch.utils.data import DataLoader +from torchvision.utils import make_grid def get_wandb_logger(trainer: Trainer) -> WandbLogger: @@ -174,23 +175,20 @@ class LogReconstuctedImages(Callback): samples = next(iter(dataloader)) imgs, _ = samples - colums = ["input", "reconstruction"] imgs = imgs.to(device=pl_module.device) reconstructions = pl_module(imgs)[0] + data = [ - [wandb.Image(img), wandb.Image(rec)] + wandb.Image( + make_grid([img, rec]), caption="Left: Image, Right: Reconstruction" + ) + # wandb.Image(rec, caption="Reconstruction"), for img, rec in zip( imgs[: self.num_samples], reconstructions[: self.num_samples] ) ] - experiment.log( - { - f"Reconstructions/{experiment.name}/{stage}": wandb.Table( - data=data, columns=colums - ) - } - ) + experiment.log({f"Reconstructions/{experiment.name}/{stage}": data}) def on_sanity_check_start( self, trainer: Trainer, pl_module: LightningModule |