diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-26 00:00:34 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-26 00:00:55 +0200 |
commit | a6c42e6f7cb70c1a06e46716f141c8f793a64e04 (patch) | |
tree | 26428f09b95e54c76001ba6e9ec1a4defb371db0 | |
parent | c8b0cf3af0bfadbc8dfa6ecb6dd8a339b6113e88 (diff) |
Remove reconstruction wandb callback
-rw-r--r-- | training/callbacks/wandb_callbacks.py | 73 |
1 files changed, 0 insertions, 73 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py index 0598fbf..978098b 100644 --- a/training/callbacks/wandb_callbacks.py +++ b/training/callbacks/wandb_callbacks.py @@ -146,76 +146,3 @@ class LogTextPredictions(Callback): self._log_predictions( stage="test", trainer=trainer, pl_module=pl_module, dataloader=dataloader ) - - -class LogReconstuctedImages(Callback): - """Log reconstructions of images.""" - - def __init__(self, num_samples: int = 8, use_sigmoid: bool = False) -> None: - self.num_samples = num_samples - self.ready = False - self.sigmoid = nn.Sigmoid() if use_sigmoid else None - - def _log_reconstruction( - self, - stage: str, - trainer: Trainer, - pl_module: LightningModule, - dataloader: DataLoader, - ) -> None: - """Logs the reconstructions.""" - if not self.ready: - return None - - logger = get_wandb_logger(trainer) - experiment = logger.experiment - - # Get a validation batch from the validation dataloader. - samples = next(iter(dataloader)) - imgs, _ = samples - - imgs = imgs.to(device=pl_module.device) - reconstructions = pl_module(imgs)[0] - reconstructions = ( - self.sigmoid(reconstructions) - if self.sigmoid is not None - else reconstructions - ) - - data = [ - wandb.Image( - make_grid([img, rec]), caption="Left: Image, Right: Reconstruction" - ) - # wandb.Image(rec, caption="Reconstruction"), - for img, rec in zip( - imgs[: self.num_samples], reconstructions[: self.num_samples] - ) - ] - - experiment.log({f"Reconstructions/{experiment.name}/{stage}": data}) - - def on_sanity_check_start( - self, trainer: Trainer, pl_module: LightningModule - ) -> None: - """Sets ready attribute.""" - self.ready = False - - def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Start executing this callback only after all validation sanity checks end.""" - self.ready = True - - def on_validation_epoch_end( - self, trainer: Trainer, pl_module: LightningModule - ) -> None: - """Logs predictions on validation epoch end.""" - dataloader = trainer.datamodule.val_dataloader() - self._log_reconstruction( - stage="val", trainer=trainer, pl_module=pl_module, dataloader=dataloader - ) - - def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Logs predictions on train epoch end.""" - dataloader = trainer.datamodule.test_dataloader() - self._log_reconstruction( - stage="test", trainer=trainer, pl_module=pl_module, dataloader=dataloader - ) |