From 8eec5f3d4ce41c2ba30f0f1ab3f506fc3c0dd263 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 18 Sep 2022 18:12:54 +0200 Subject: Formats --- training/callbacks/wandb_callbacks.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) (limited to 'training/callbacks') diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py index aa72480..dc59f19 100644 --- a/training/callbacks/wandb_callbacks.py +++ b/training/callbacks/wandb_callbacks.py @@ -27,15 +27,23 @@ def get_wandb_logger(trainer: Trainer) -> WandbLogger: class WatchModel(Callback): """Make W&B watch the model at the beginning of the run.""" - def __init__(self, log: str = "gradients", log_freq: int = 100) -> None: + def __init__( + self, log: str = "gradients", log_freq: int = 100, log_graph: bool = False + ) -> None: self.log = log self.log_freq = log_freq + self.log_graph = log_graph @rank_zero_only def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: """Watches model weights with wandb.""" logger = get_wandb_logger(trainer) - logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq) + logger.watch( + model=trainer.model, + log=self.log, + log_freq=self.log_freq, + log_graph=self.log_graph, + ) class UploadConfigAsArtifact(Callback): -- cgit v1.2.3-70-g09d2