summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--training/callbacks/wandb_callbacks.py73
1 files changed, 0 insertions, 73 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py
index 0598fbf..978098b 100644
--- a/training/callbacks/wandb_callbacks.py
+++ b/training/callbacks/wandb_callbacks.py
@@ -146,76 +146,3 @@ class LogTextPredictions(Callback):
self._log_predictions(
stage="test", trainer=trainer, pl_module=pl_module, dataloader=dataloader
)
-
-
-class LogReconstuctedImages(Callback):
- """Log reconstructions of images."""
-
- def __init__(self, num_samples: int = 8, use_sigmoid: bool = False) -> None:
- self.num_samples = num_samples
- self.ready = False
- self.sigmoid = nn.Sigmoid() if use_sigmoid else None
-
- def _log_reconstruction(
- self,
- stage: str,
- trainer: Trainer,
- pl_module: LightningModule,
- dataloader: DataLoader,
- ) -> None:
- """Logs the reconstructions."""
- if not self.ready:
- return None
-
- logger = get_wandb_logger(trainer)
- experiment = logger.experiment
-
- # Get a validation batch from the validation dataloader.
- samples = next(iter(dataloader))
- imgs, _ = samples
-
- imgs = imgs.to(device=pl_module.device)
- reconstructions = pl_module(imgs)[0]
- reconstructions = (
- self.sigmoid(reconstructions)
- if self.sigmoid is not None
- else reconstructions
- )
-
- data = [
- wandb.Image(
- make_grid([img, rec]), caption="Left: Image, Right: Reconstruction"
- )
- # wandb.Image(rec, caption="Reconstruction"),
- for img, rec in zip(
- imgs[: self.num_samples], reconstructions[: self.num_samples]
- )
- ]
-
- experiment.log({f"Reconstructions/{experiment.name}/{stage}": data})
-
- def on_sanity_check_start(
- self, trainer: Trainer, pl_module: LightningModule
- ) -> None:
- """Sets ready attribute."""
- self.ready = False
-
- def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
- """Start executing this callback only after all validation sanity checks end."""
- self.ready = True
-
- def on_validation_epoch_end(
- self, trainer: Trainer, pl_module: LightningModule
- ) -> None:
- """Logs predictions on validation epoch end."""
- dataloader = trainer.datamodule.val_dataloader()
- self._log_reconstruction(
- stage="val", trainer=trainer, pl_module=pl_module, dataloader=dataloader
- )
-
- def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
- """Logs predictions on train epoch end."""
- dataloader = trainer.datamodule.test_dataloader()
- self._log_reconstruction(
- stage="test", trainer=trainer, pl_module=pl_module, dataloader=dataloader
- )