diff options
Diffstat (limited to 'training/callbacks/wandb_callbacks.py')
-rw-r--r-- | training/callbacks/wandb_callbacks.py | 12 |
1 files changed, 10 insertions, 2 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py index aa72480..dc59f19 100644 --- a/training/callbacks/wandb_callbacks.py +++ b/training/callbacks/wandb_callbacks.py @@ -27,15 +27,23 @@ def get_wandb_logger(trainer: Trainer) -> WandbLogger: class WatchModel(Callback): """Make W&B watch the model at the beginning of the run.""" - def __init__(self, log: str = "gradients", log_freq: int = 100) -> None: + def __init__( + self, log: str = "gradients", log_freq: int = 100, log_graph: bool = False + ) -> None: self.log = log self.log_freq = log_freq + self.log_graph = log_graph @rank_zero_only def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: """Watches model weights with wandb.""" logger = get_wandb_logger(trainer) - logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq) + logger.watch( + model=trainer.model, + log=self.log, + log_freq=self.log_freq, + log_graph=self.log_graph, + ) class UploadConfigAsArtifact(Callback): |