diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-18 18:12:54 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-18 18:12:54 +0200 |
commit | 8eec5f3d4ce41c2ba30f0f1ab3f506fc3c0dd263 (patch) | |
tree | 19244a467e90ed8b919cdf2912d9e1b141ea1801 /training | |
parent | 2cc6aa059139b57057609817913ad515063c2eab (diff) |
Formats
Diffstat (limited to 'training')
-rw-r--r-- | training/callbacks/wandb_callbacks.py | 12 |
1 files changed, 10 insertions, 2 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py index aa72480..dc59f19 100644 --- a/training/callbacks/wandb_callbacks.py +++ b/training/callbacks/wandb_callbacks.py @@ -27,15 +27,23 @@ def get_wandb_logger(trainer: Trainer) -> WandbLogger: class WatchModel(Callback): """Make W&B watch the model at the beginning of the run.""" - def __init__(self, log: str = "gradients", log_freq: int = 100) -> None: + def __init__( + self, log: str = "gradients", log_freq: int = 100, log_graph: bool = False + ) -> None: self.log = log self.log_freq = log_freq + self.log_graph = log_graph @rank_zero_only def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: """Watches model weights with wandb.""" logger = get_wandb_logger(trainer) - logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq) + logger.watch( + model=trainer.model, + log=self.log, + log_freq=self.log_freq, + log_graph=self.log_graph, + ) class UploadConfigAsArtifact(Callback): |