summaryrefslogtreecommitdiff
path: root/training/callbacks/wandb_callbacks.py
diff options
context:
space:
mode:
author: Gustaf Rydholm <gustaf.rydholm@gmail.com> 2023-09-02 01:53:37 +0200
committer: Gustaf Rydholm <gustaf.rydholm@gmail.com> 2023-09-02 01:53:37 +0200
commit: abc2d60d69d115cdb34615d8bcb6c03ab6357141 (patch)
tree: 74f8536f7ca072f917fd924d2528ccaf0c273b49 /training/callbacks/wandb_callbacks.py
parent: 617bf7f0285090b85817a398ef4bb871d4f616e9 (diff)
Refactor wandb callbacks
Diffstat (limited to 'training/callbacks/wandb_callbacks.py')
-rw-r--r--training/callbacks/wandb_callbacks.py153
1 files changed, 0 insertions, 153 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py
deleted file mode 100644
index 1c7955c..0000000
--- a/training/callbacks/wandb_callbacks.py
+++ /dev/null
@@ -1,153 +0,0 @@
-"""Weights and Biases callbacks."""
-from pathlib import Path
-
-import wandb
-from pytorch_lightning import Callback, LightningModule, Trainer
-from pytorch_lightning.loggers import Logger, WandbLogger
-from pytorch_lightning.utilities import rank_zero_only
-from torch.utils.data import DataLoader
-
-
def get_wandb_logger(trainer: Trainer) -> WandbLogger:
    """Return the first ``WandbLogger`` found in ``trainer.loggers``.

    Args:
        trainer: Trainer whose ``loggers`` collection is searched in order.

    Returns:
        The first logger that is an instance of ``WandbLogger``.

    Raises:
        Exception: If no logger in ``trainer.loggers`` is a ``WandbLogger``.
    """
    wandb_loggers = (
        candidate for candidate in trainer.loggers if isinstance(candidate, WandbLogger)
    )
    found = next(wandb_loggers, None)
    if found is None:
        raise Exception("Weight and Biases logger not found for some reason...")
    return found
-
-
class WatchModel(Callback):
    """Callback that asks the W&B logger to watch the model when training starts."""

    def __init__(
        self,
        log_params: str = "gradients",
        log_freq: int = 100,
        log_graph: bool = False,
    ) -> None:
        """Store the settings that are later forwarded to the logger's ``watch``.

        Args:
            log_params: Passed to ``watch`` as its ``log`` argument.
            log_freq: Passed to ``watch`` as ``log_freq``.
            log_graph: Passed to ``watch`` as ``log_graph``.
        """
        self.log_graph = log_graph
        self.log_freq = log_freq
        self.log_params = log_params

    @rank_zero_only
    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Call ``watch`` on the trainer's W&B logger for ``trainer.model``."""
        # Resolve the logger first so a missing W&B logger fails before anything else.
        wandb_logger = get_wandb_logger(trainer)
        watch_kwargs = {
            "model": trainer.model,
            "log": self.log_params,
            "log_freq": self.log_freq,
            "log_graph": self.log_graph,
        }
        wandb_logger.watch(**watch_kwargs)
-
-
class UploadConfigAsArtifact(Callback):
    """Upload all *.yaml files under ``.hydra/`` to W&B as an artifact at train start."""

    def __init__(self) -> None:
        """Set the directory that is searched for config files to ``.hydra/``."""
        # Relative path: resolved against the process working directory at upload time.
        self.config_dir = Path(".hydra/")

    @rank_zero_only
    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Upload the run configuration as an ``experiment-config`` artifact.

        Every ``*.yaml`` file found recursively under ``self.config_dir`` is added
        to a single artifact of type ``config``.

        Raises:
            Exception: If the trainer has no ``WandbLogger`` attached.
        """
        experiment = get_wandb_logger(trainer).experiment
        artifact = wandb.Artifact("experiment-config", type="config")
        for filepath in self.config_dir.rglob("*.yaml"):
            artifact.add_file(str(filepath))

        # ``log_artifact`` is the upload call; ``use_artifact`` declares an *input*
        # of the run and is not usable in offline mode.
        experiment.log_artifact(artifact)
-
-
class UploadCheckpointsAsArtifact(Callback):
    """Upload checkpoints to W&B as an artifact at the end of training."""

    def __init__(
        self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False
    ) -> None:
        """Configure which checkpoints get uploaded.

        Args:
            ckpt_dir: Directory searched recursively for ``*.ckpt`` files when
                ``upload_best_only`` is false.
            upload_best_only: If true, upload only the file referenced by
                ``trainer.checkpoint_callback.best_model_path``.
        """
        self.ckpt_dir = Path(ckpt_dir)
        self.upload_best_only = upload_best_only

    @rank_zero_only
    def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Upload model checkpoint(s) as an ``experiment-ckpts`` artifact.

        Raises:
            Exception: If the trainer has no ``WandbLogger`` attached.
        """
        experiment = get_wandb_logger(trainer).experiment
        ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints")

        if self.upload_best_only:
            # ``checkpoint_callback`` may be None and ``best_model_path`` may be
            # empty when nothing was saved; skip instead of crashing after training.
            best_path = getattr(trainer.checkpoint_callback, "best_model_path", None)
            if not best_path:
                import warnings

                warnings.warn(
                    "No best checkpoint available; skipping checkpoint upload."
                )
                return
            ckpts.add_file(str(best_path))
        else:
            for ckpt in self.ckpt_dir.rglob("*.ckpt"):
                # Pass a plain string, matching how config files are added.
                ckpts.add_file(str(ckpt))

        # Checkpoints are outputs of the run, so log them rather than declaring
        # them as inputs with ``use_artifact``.
        experiment.log_artifact(ckpts)
-
-
class LogTextPredictions(Callback):
    """Log images from one dataloader batch to W&B, captioned with predicted text."""

    def __init__(self, num_samples: int = 8) -> None:
        """Configure the callback.

        Args:
            num_samples: Maximum number of images taken from the batch.
        """
        self.num_samples = num_samples
        # Enabled by default so logging still happens when the sanity check is
        # skipped or only validate/test is run; the sanity-check hooks below
        # disable it for the duration of the check.
        self.ready = True

    def _log_predictions(
        self,
        stage: str,
        trainer: Trainer,
        pl_module: LightningModule,
        dataloader: DataLoader,
    ) -> None:
        """Log the text predicted for the first batch of ``dataloader``.

        Args:
            stage: Label used as the last component of the logged key.
            trainer: Trainer providing the W&B logger.
            pl_module: Module providing ``device``, ``predict`` and ``tokenizer``.
            dataloader: Loader whose first batch is used; the batch is unpacked
                as an ``(images, labels)`` pair.

        Raises:
            Exception: If the trainer has no ``WandbLogger`` attached.
        """
        if not self.ready:
            return

        experiment = get_wandb_logger(trainer).experiment

        # Only the first batch of the given dataloader is used; labels are not logged.
        imgs, _ = next(iter(dataloader))

        imgs = imgs.to(device=pl_module.device)
        preds = pl_module.predict(imgs)

        tokenizer = pl_module.tokenizer
        data = [
            wandb.Image(img, caption=tokenizer.decode(pred))
            for img, pred in zip(imgs[: self.num_samples], preds[: self.num_samples])
        ]

        experiment.log({f"HTR/{experiment.name}/{stage}": data})

    def on_sanity_check_start(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> None:
        """Disable logging while the validation sanity check runs."""
        self.ready = False

    def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Re-enable logging once the validation sanity check has finished."""
        self.ready = True

    def on_validation_epoch_end(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> None:
        """Log predictions on validation epoch end."""
        dataloader = trainer.datamodule.val_dataloader()
        self._log_predictions(
            stage="val", trainer=trainer, pl_module=pl_module, dataloader=dataloader
        )

    def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Log predictions on test epoch end."""
        dataloader = trainer.datamodule.test_dataloader()
        self._log_predictions(
            stage="test", trainer=trainer, pl_module=pl_module, dataloader=dataloader
        )