summaryrefslogtreecommitdiff
path: root/training/callbacks
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2023-08-25 23:19:39 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2023-08-25 23:19:39 +0200
commit6968572c1a21394b88a29f675b17b9698784a898 (patch)
treed89d1c5c2ec331d38dcb5b6a2dbbd72c9e355b8a /training/callbacks
parent49ca6ade1a19f7f9c702171537fe4be0dfcda66d (diff)
Update training stuff
Diffstat (limited to 'training/callbacks')
-rw-r--r--training/callbacks/wandb_callbacks.py23
1 files changed, 10 insertions, 13 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py
index 7d4f48e..1c7955c 100644
--- a/training/callbacks/wandb_callbacks.py
+++ b/training/callbacks/wandb_callbacks.py
@@ -3,23 +3,17 @@ from pathlib import Path
import wandb
from pytorch_lightning import Callback, LightningModule, Trainer
-from pytorch_lightning.loggers import LoggerCollection, WandbLogger
+from pytorch_lightning.loggers import Logger, WandbLogger
from pytorch_lightning.utilities import rank_zero_only
-from torch import nn
from torch.utils.data import DataLoader
-from torchvision.utils import make_grid
def get_wandb_logger(trainer: Trainer) -> WandbLogger:
"""Safely get W&B logger from Trainer."""
- if isinstance(trainer.logger, WandbLogger):
- return trainer.logger
-
- if isinstance(trainer.logger, LoggerCollection):
- for logger in trainer.logger:
- if isinstance(logger, WandbLogger):
- return logger
+ for logger in trainer.loggers:
+ if isinstance(logger, WandbLogger):
+ return logger
raise Exception("Weight and Biases logger not found for some reason...")
@@ -28,9 +22,12 @@ class WatchModel(Callback):
"""Make W&B watch the model at the beginning of the run."""
def __init__(
- self, log: str = "gradients", log_freq: int = 100, log_graph: bool = False
+ self,
+ log_params: str = "gradients",
+ log_freq: int = 100,
+ log_graph: bool = False,
) -> None:
- self.log = log
+ self.log_params = log_params
self.log_freq = log_freq
self.log_graph = log_graph
@@ -40,7 +37,7 @@ class WatchModel(Callback):
logger = get_wandb_logger(trainer)
logger.watch(
model=trainer.model,
- log=self.log,
+ log=self.log_params,
log_freq=self.log_freq,
log_graph=self.log_graph,
)