diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-04 05:03:51 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-04 05:03:51 +0200 |
commit | d3afa310f77f47553586eeee58e3d3345a754e2c (patch) | |
tree | 08b7de1daf2550852d0a1e4d4d75202f14bb03d4 /training/callbacks | |
parent | 65d5f6c694e73792e40ed693a1381a792da8d277 (diff) |
New VQVAE
Diffstat (limited to 'training/callbacks')
-rw-r--r-- | training/callbacks/wandb_callbacks.py | 69 |
1 files changed, 46 insertions, 23 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py index 906531f..c750e4b 100644 --- a/training/callbacks/wandb_callbacks.py +++ b/training/callbacks/wandb_callbacks.py @@ -5,6 +5,7 @@ import wandb from pytorch_lightning import Callback, LightningModule, Trainer from pytorch_lightning.loggers import LoggerCollection, WandbLogger from pytorch_lightning.utilities import rank_zero_only +from torch.utils.data import DataLoader def get_wandb_logger(trainer: Trainer) -> WandbLogger: @@ -86,7 +87,11 @@ class LogTextPredictions(Callback): self.ready = False def _log_predictions( - self, stage: str, trainer: Trainer, pl_module: LightningModule + self, + stage: str, + trainer: Trainer, + pl_module: LightningModule, + dataloader: DataLoader, ) -> None: """Logs the predicted text contained in the images.""" if not self.ready: @@ -96,22 +101,20 @@ class LogTextPredictions(Callback): experiment = logger.experiment # Get a validation batch from the validation dataloader. - samples = next(iter(trainer.datamodule.val_dataloader())) + samples = next(iter(dataloader)) imgs, labels = samples imgs = imgs.to(device=pl_module.device) logits = pl_module(imgs) mapping = pl_module.mapping - columns = ["id", "image", "prediction", "truth"] + columns = ["image", "prediction", "truth"] data = [ - [id, wandb.Image(img), mapping.get_text(pred), mapping.get_text(label)] - for id, (img, pred, label) in enumerate( - zip( - imgs[: self.num_samples], - logits[: self.num_samples], - labels[: self.num_samples], - ) + [wandb.Image(img), mapping.get_text(pred), mapping.get_text(label)] + for img, pred, label in zip( + imgs[: self.num_samples], + logits[: self.num_samples], + labels[: self.num_samples], ) ] @@ -133,11 +136,17 @@ class LogTextPredictions(Callback): self, trainer: Trainer, pl_module: LightningModule ) -> None: """Logs predictions on validation epoch end.""" - self._log_predictions(stage="val", trainer=trainer, pl_module=pl_module) + dataloader = trainer.datamodule.val_dataloader() + self._log_predictions( + stage="val", trainer=trainer, pl_module=pl_module, dataloader=dataloader + ) def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: """Logs predictions on train epoch end.""" - self._log_predictions(stage="test", trainer=trainer, pl_module=pl_module) + dataloader = trainer.datamodule.test_dataloader() + self._log_predictions( + stage="test", trainer=trainer, pl_module=pl_module, dataloader=dataloader + ) class LogReconstuctedImages(Callback): @@ -148,7 +157,11 @@ class LogReconstuctedImages(Callback): self.ready = False def _log_reconstruction( - self, stage: str, trainer: Trainer, pl_module: LightningModule + self, + stage: str, + trainer: Trainer, + pl_module: LightningModule, + dataloader: DataLoader, ) -> None: """Logs the reconstructions.""" if not self.ready: @@ -158,20 +171,24 @@ class LogReconstuctedImages(Callback): experiment = logger.experiment # Get a validation batch from the validation dataloader. - samples = next(iter(trainer.datamodule.val_dataloader())) + samples = next(iter(dataloader)) imgs, _ = samples + colums = ["input", "reconstruction"] imgs = imgs.to(device=pl_module.device) - reconstructions = pl_module(imgs) + reconstructions = pl_module(imgs)[0] + data = [ + [wandb.Image(img), wandb.Image(rec)] + for img, rec in zip( + imgs[: self.num_samples], reconstructions[: self.num_samples] + ) + ] experiment.log( { - f"Reconstructions/{experiment.name}/{stage}": [ - [wandb.Image(img), wandb.Image(rec),] - for img, rec in zip( - imgs[: self.num_samples], reconstructions[: self.num_samples], - ) - ] + f"Reconstructions/{experiment.name}/{stage}": wandb.Table( + data=data, columns=colums + ) } ) @@ -189,8 +206,14 @@ class LogReconstuctedImages(Callback): self, trainer: Trainer, pl_module: LightningModule ) -> None: """Logs predictions on validation epoch end.""" - self._log_reconstruction(stage="val", trainer=trainer, pl_module=pl_module) + dataloader = trainer.datamodule.val_dataloader() + self._log_reconstruction( + stage="val", trainer=trainer, pl_module=pl_module, dataloader=dataloader + ) def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: """Logs predictions on train epoch end.""" - self._log_reconstruction(stage="test", trainer=trainer, pl_module=pl_module) + dataloader = trainer.datamodule.test_dataloader() + self._log_reconstruction( + stage="test", trainer=trainer, pl_module=pl_module, dataloader=dataloader + ) |