diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2023-09-02 01:53:37 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2023-09-02 01:53:37 +0200 |
commit | abc2d60d69d115cdb34615d8bcb6c03ab6357141 (patch) | |
tree | 74f8536f7ca072f917fd924d2528ccaf0c273b49 /training/callbacks/wandb_callbacks.py | |
parent | 617bf7f0285090b85817a398ef4bb871d4f616e9 (diff) |
Refactor wandb callbacks
Diffstat (limited to 'training/callbacks/wandb_callbacks.py')
-rw-r--r-- | training/callbacks/wandb_callbacks.py | 153 |
1 files changed, 0 insertions, 153 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py deleted file mode 100644 index 1c7955c..0000000 --- a/training/callbacks/wandb_callbacks.py +++ /dev/null @@ -1,153 +0,0 @@ -"""Weights and Biases callbacks.""" -from pathlib import Path - -import wandb -from pytorch_lightning import Callback, LightningModule, Trainer -from pytorch_lightning.loggers import Logger, WandbLogger -from pytorch_lightning.utilities import rank_zero_only -from torch.utils.data import DataLoader - - -def get_wandb_logger(trainer: Trainer) -> WandbLogger: - """Safely get W&B logger from Trainer.""" - - for logger in trainer.loggers: - if isinstance(logger, WandbLogger): - return logger - - raise Exception("Weight and Biases logger not found for some reason...") - - -class WatchModel(Callback): - """Make W&B watch the model at the beginning of the run.""" - - def __init__( - self, - log_params: str = "gradients", - log_freq: int = 100, - log_graph: bool = False, - ) -> None: - self.log_params = log_params - self.log_freq = log_freq - self.log_graph = log_graph - - @rank_zero_only - def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Watches model weights with wandb.""" - logger = get_wandb_logger(trainer) - logger.watch( - model=trainer.model, - log=self.log_params, - log_freq=self.log_freq, - log_graph=self.log_graph, - ) - - -class UploadConfigAsArtifact(Callback): - """Upload all *.py files to W&B as an artifact, at the beginning of the run.""" - - def __init__(self) -> None: - self.config_dir = Path(".hydra/") - - @rank_zero_only - def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Uploads project code as an artifact.""" - logger = get_wandb_logger(trainer) - experiment = logger.experiment - artifact = wandb.Artifact("experiment-config", type="config") - for filepath in self.config_dir.rglob("*.yaml"): - artifact.add_file(str(filepath)) - - experiment.use_artifact(artifact) - - -class UploadCheckpointsAsArtifact(Callback): - """Upload checkpoint to wandb as an artifact, at the end of a run.""" - - def __init__( - self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False - ) -> None: - self.ckpt_dir = Path(ckpt_dir) - self.upload_best_only = upload_best_only - - @rank_zero_only - def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Uploads model checkpoint to W&B.""" - logger = get_wandb_logger(trainer) - experiment = logger.experiment - ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints") - - if self.upload_best_only: - ckpts.add_file(trainer.checkpoint_callback.best_model_path) - else: - for ckpt in (self.ckpt_dir).rglob("*.ckpt"): - ckpts.add_file(ckpt) - - experiment.use_artifact(ckpts) - - -class LogTextPredictions(Callback): - """Logs a validation batch with image to text transcription.""" - - def __init__(self, num_samples: int = 8) -> None: - self.num_samples = num_samples - self.ready = False - - def _log_predictions( - self, - stage: str, - trainer: Trainer, - pl_module: LightningModule, - dataloader: DataLoader, - ) -> None: - """Logs the predicted text contained in the images.""" - if not self.ready: - return None - - logger = get_wandb_logger(trainer) - experiment = logger.experiment - - # Get a validation batch from the validation dataloader. - samples = next(iter(dataloader)) - imgs, labels = samples - - imgs = imgs.to(device=pl_module.device) - logits = pl_module.predict(imgs) - - tokenizer = pl_module.tokenizer - data = [ - wandb.Image(img, caption=tokenizer.decode(pred)) - for img, pred, label in zip( - imgs[: self.num_samples], - logits[: self.num_samples], - labels[: self.num_samples], - ) - ] - - experiment.log({f"HTR/{experiment.name}/{stage}": data}) - - def on_sanity_check_start( - self, trainer: Trainer, pl_module: LightningModule - ) -> None: - """Sets ready attribute.""" - self.ready = False - - def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Start executing this callback only after all validation sanity checks end.""" - self.ready = True - - def on_validation_epoch_end( - self, trainer: Trainer, pl_module: LightningModule - ) -> None: - """Logs predictions on validation epoch end.""" - dataloader = trainer.datamodule.val_dataloader() - self._log_predictions( - stage="val", trainer=trainer, pl_module=pl_module, dataloader=dataloader - ) - - def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Logs predictions on train epoch end.""" - dataloader = trainer.datamodule.test_dataloader() - self._log_predictions( - stage="test", trainer=trainer, pl_module=pl_module, dataloader=dataloader - ) |