summaryrefslogtreecommitdiff
path: root/training/callbacks
diff options
context:
space:
mode:
Diffstat (limited to 'training/callbacks')
-rw-r--r--training/callbacks/wandb_callbacks.py16
1 files changed, 7 insertions, 9 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py
index 11d0936..68e4135 100644
--- a/training/callbacks/wandb_callbacks.py
+++ b/training/callbacks/wandb_callbacks.py
@@ -6,6 +6,7 @@ from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.loggers import LoggerCollection, WandbLogger
from pytorch_lightning.utilities import rank_zero_only
from torch.utils.data import DataLoader
+from torchvision.utils import make_grid
def get_wandb_logger(trainer: Trainer) -> WandbLogger:
@@ -174,23 +175,20 @@ class LogReconstuctedImages(Callback):
samples = next(iter(dataloader))
imgs, _ = samples
- colums = ["input", "reconstruction"]
imgs = imgs.to(device=pl_module.device)
reconstructions = pl_module(imgs)[0]
+
data = [
- [wandb.Image(img), wandb.Image(rec)]
+ wandb.Image(
+ make_grid([img, rec]), caption="Left: Image, Right: Reconstruction"
+ )
+ # wandb.Image(rec, caption="Reconstruction"),
for img, rec in zip(
imgs[: self.num_samples], reconstructions[: self.num_samples]
)
]
- experiment.log(
- {
- f"Reconstructions/{experiment.name}/{stage}": wandb.Table(
- data=data, columns=colums
- )
- }
- )
+ experiment.log({f"Reconstructions/{experiment.name}/{stage}": data})
def on_sanity_check_start(
self, trainer: Trainer, pl_module: LightningModule