summary | refs | log | tree | commit | diff
path: root/training/callbacks/wandb.py
diff options
context:
space:
mode:
author: Gustaf Rydholm <gustaf.rydholm@gmail.com> 2023-09-02 01:53:37 +0200
committer: Gustaf Rydholm <gustaf.rydholm@gmail.com> 2023-09-02 01:53:37 +0200
commit: abc2d60d69d115cdb34615d8bcb6c03ab6357141 (patch)
tree: 74f8536f7ca072f917fd924d2528ccaf0c273b49 /training/callbacks/wandb.py
parent: 617bf7f0285090b85817a398ef4bb871d4f616e9 (diff)
Refactor wandb callbacks
Diffstat (limited to 'training/callbacks/wandb.py')
-rw-r--r-- training/callbacks/wandb.py 153
1 files changed, 153 insertions, 0 deletions
diff --git a/training/callbacks/wandb.py b/training/callbacks/wandb.py
new file mode 100644
index 0000000..d9bb9b8
--- /dev/null
+++ b/training/callbacks/wandb.py
@@ -0,0 +1,153 @@
+"""Weights and Biases callbacks."""
+from pathlib import Path
+from typing import Tuple
+
+import wandb
+from torch import Tensor
+from pytorch_lightning import Callback, LightningModule, Trainer
+from pytorch_lightning.loggers import WandbLogger
+from pytorch_lightning.utilities import rank_zero_only
+
+
+def get_wandb_logger(trainer: Trainer) -> WandbLogger:
+ """Safely get W&B logger from Trainer."""
+
+ for logger in trainer.loggers:
+ if isinstance(logger, WandbLogger):
+ return logger
+
+ raise Exception("Weight and Biases logger not found for some reason...")
+
+
+class WatchModel(Callback):
+ """Make W&B watch the model at the beginning of the run."""
+
+ def __init__(
+ self,
+ log_params: str = "gradients",
+ log_freq: int = 100,
+ log_graph: bool = False,
+ ) -> None:
+ self.log_params = log_params
+ self.log_freq = log_freq
+ self.log_graph = log_graph
+
+ @rank_zero_only
+ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+ """Watches model weights with wandb."""
+ logger = get_wandb_logger(trainer)
+ logger.watch(
+ model=trainer.model,
+ log=self.log_params,
+ log_freq=self.log_freq,
+ log_graph=self.log_graph,
+ )
+
+
+class UploadConfigAsArtifact(Callback):
+ """Upload all *.py files to W&B as an artifact, at the beginning of the run."""
+
+ def __init__(self) -> None:
+ self.config_dir = Path(".hydra/")
+
+ @rank_zero_only
+ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+ """Uploads project code as an artifact."""
+ logger = get_wandb_logger(trainer)
+ experiment = logger.experiment
+ artifact = wandb.Artifact("experiment-config", type="config")
+ for filepath in self.config_dir.rglob("*.yaml"):
+ artifact.add_file(str(filepath))
+
+ experiment.use_artifact(artifact)
+
+
+class UploadCheckpointsAsArtifact(Callback):
+ """Upload checkpoint to wandb as an artifact, at the end of a run."""
+
+ def __init__(
+ self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False
+ ) -> None:
+ self.ckpt_dir = Path(ckpt_dir)
+ self.upload_best_only = upload_best_only
+
+ @rank_zero_only
+ def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+ """Uploads model checkpoint to W&B."""
+ logger = get_wandb_logger(trainer)
+ experiment = logger.experiment
+ ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints")
+
+ if self.upload_best_only:
+ ckpts.add_file(trainer.checkpoint_callback.best_model_path)
+ else:
+ for ckpt in (self.ckpt_dir).rglob("*.ckpt"):
+ ckpts.add_file(ckpt)
+
+ experiment.use_artifact(ckpts)
+
+
+class ImageToCaptionLogger(Callback):
+ """Logs the image and output caption."""
+
+ def __init__(self, num_samples: int = 8, on_train: bool = True) -> None:
+ self.num_samples = num_samples
+ self.on_train = on_train
+ self._required_keys = ("predictions", "ground_truths")
+
+ def _log_captions(
+ self, trainer: Trainer, batch: Tuple[Tensor, Tensor], outputs: dict, key: str
+ ) -> None:
+ xs, _ = batch
+ preds, gts = outputs["predictions"], outputs["ground_truths"]
+ xs, preds, gts = (
+ list(xs[: self.num_samples]),
+ preds[: self.num_samples],
+ gts[: self.num_samples],
+ )
+ trainer.logger.log_image(key, xs, caption=preds)
+
+ @rank_zero_only
+ def on_train_batch_end(
+ self,
+ trainer: Trainer,
+ pl_module: LightningModule,
+ outputs: dict,
+ batch: Tuple[Tensor, Tensor],
+ batch_idx: int,
+ ) -> None:
+ """Logs predictions on validation batch end."""
+ if self.has_metrics(outputs):
+ self._log_captions(trainer, batch, outputs, "train/predictions")
+
+ @rank_zero_only
+ def on_validation_batch_end(
+ self,
+ trainer: Trainer,
+ pl_module: LightningModule,
+ outputs: dict,
+ batch: Tuple[Tensor, Tensor],
+ batch_idx: int,
+ *args,
+ # dataloader_idx: int,
+ ) -> None:
+ """Logs predictions on validation batch end."""
+ if self.has_metrics(outputs):
+ self._log_captions(trainer, batch, outputs, "val/predictions")
+
+ @rank_zero_only
+ def on_test_batch_end(
+ self,
+ trainer: Trainer,
+ pl_module: LightningModule,
+ outputs: dict,
+ batch: Tuple[Tensor, Tensor],
+ batch_idx: int,
+ dataloader_idx: int,
+ ) -> None:
+ """Logs predictions on train batch end."""
+ if self.has_metrics(outputs):
+ self._log_captions(trainer, batch, outputs, "test/predictions")
+
+ def has_metrics(self, outputs: dict) -> bool:
+ return all(k in outputs.keys() for k in self._required_keys)