summaryrefslogtreecommitdiff
path: root/training/callbacks
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-18 18:12:54 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-18 18:12:54 +0200
commit8eec5f3d4ce41c2ba30f0f1ab3f506fc3c0dd263 (patch)
tree19244a467e90ed8b919cdf2912d9e1b141ea1801 /training/callbacks
parent2cc6aa059139b57057609817913ad515063c2eab (diff)
Formats
Diffstat (limited to 'training/callbacks')
-rw-r--r--training/callbacks/wandb_callbacks.py12
1 files changed, 10 insertions, 2 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py
index aa72480..dc59f19 100644
--- a/training/callbacks/wandb_callbacks.py
+++ b/training/callbacks/wandb_callbacks.py
@@ -27,15 +27,23 @@ def get_wandb_logger(trainer: Trainer) -> WandbLogger:
class WatchModel(Callback):
"""Make W&B watch the model at the beginning of the run."""
- def __init__(self, log: str = "gradients", log_freq: int = 100) -> None:
+ def __init__(
+ self, log: str = "gradients", log_freq: int = 100, log_graph: bool = False
+ ) -> None:
self.log = log
self.log_freq = log_freq
+ self.log_graph = log_graph
@rank_zero_only
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Watches model weights with wandb."""
logger = get_wandb_logger(trainer)
- logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq)
+ logger.watch(
+ model=trainer.model,
+ log=self.log,
+ log_freq=self.log_freq,
+ log_graph=self.log_graph,
+ )
class UploadConfigAsArtifact(Callback):