diff options
Diffstat (limited to 'training/callbacks')
-rw-r--r-- | training/callbacks/wandb_callbacks.py | 23 |
1 files changed, 10 insertions, 13 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py index 7d4f48e..1c7955c 100644 --- a/training/callbacks/wandb_callbacks.py +++ b/training/callbacks/wandb_callbacks.py @@ -3,23 +3,17 @@ from pathlib import Path import wandb from pytorch_lightning import Callback, LightningModule, Trainer -from pytorch_lightning.loggers import LoggerCollection, WandbLogger +from pytorch_lightning.loggers import Logger, WandbLogger from pytorch_lightning.utilities import rank_zero_only -from torch import nn from torch.utils.data import DataLoader -from torchvision.utils import make_grid def get_wandb_logger(trainer: Trainer) -> WandbLogger: """Safely get W&B logger from Trainer.""" - if isinstance(trainer.logger, WandbLogger): - return trainer.logger - - if isinstance(trainer.logger, LoggerCollection): - for logger in trainer.logger: - if isinstance(logger, WandbLogger): - return logger + for logger in trainer.loggers: + if isinstance(logger, WandbLogger): + return logger raise Exception("Weight and Biases logger not found for some reason...") @@ -28,9 +22,12 @@ class WatchModel(Callback): """Make W&B watch the model at the beginning of the run.""" def __init__( - self, log: str = "gradients", log_freq: int = 100, log_graph: bool = False + self, + log_params: str = "gradients", + log_freq: int = 100, + log_graph: bool = False, ) -> None: - self.log = log + self.log_params = log_params self.log_freq = log_freq self.log_graph = log_graph @@ -40,7 +37,7 @@ class WatchModel(Callback): logger = get_wandb_logger(trainer) logger.watch( model=trainer.model, - log=self.log, + log=self.log_params, log_freq=self.log_freq, log_graph=self.log_graph, ) |