summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/callbacks/wandb_callbacks.py12
1 files changed, 10 insertions, 2 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py
index aa72480..dc59f19 100644
--- a/training/callbacks/wandb_callbacks.py
+++ b/training/callbacks/wandb_callbacks.py
@@ -27,15 +27,23 @@ def get_wandb_logger(trainer: Trainer) -> WandbLogger:
class WatchModel(Callback):
"""Make W&B watch the model at the beginning of the run."""
- def __init__(self, log: str = "gradients", log_freq: int = 100) -> None:
+ def __init__(
+ self, log: str = "gradients", log_freq: int = 100, log_graph: bool = False
+ ) -> None:
self.log = log
self.log_freq = log_freq
+ self.log_graph = log_graph
@rank_zero_only
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Watches model weights with wandb."""
logger = get_wandb_logger(trainer)
- logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq)
+ logger.watch(
+ model=trainer.model,
+ log=self.log,
+ log_freq=self.log_freq,
+ log_graph=self.log_graph,
+ )
class UploadConfigAsArtifact(Callback):