From abc2d60d69d115cdb34615d8bcb6c03ab6357141 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sat, 2 Sep 2023 01:53:37 +0200 Subject: Refactor wandb callbacks --- training/callbacks/wandb.py | 153 ++++++++++++++++++++++++++++++++++ training/callbacks/wandb_callbacks.py | 153 ---------------------------------- 2 files changed, 153 insertions(+), 153 deletions(-) create mode 100644 training/callbacks/wandb.py delete mode 100644 training/callbacks/wandb_callbacks.py (limited to 'training/callbacks') diff --git a/training/callbacks/wandb.py b/training/callbacks/wandb.py new file mode 100644 index 0000000..d9bb9b8 --- /dev/null +++ b/training/callbacks/wandb.py @@ -0,0 +1,153 @@ +"""Weights and Biases callbacks.""" +from pathlib import Path +from typing import Tuple + +import wandb +from torch import Tensor +from pytorch_lightning import Callback, LightningModule, Trainer +from pytorch_lightning.loggers import WandbLogger +from pytorch_lightning.utilities import rank_zero_only + + +def get_wandb_logger(trainer: Trainer) -> WandbLogger: + """Safely get W&B logger from Trainer.""" + + for logger in trainer.loggers: + if isinstance(logger, WandbLogger): + return logger + + raise Exception("Weight and Biases logger not found for some reason...") + + +class WatchModel(Callback): + """Make W&B watch the model at the beginning of the run.""" + + def __init__( + self, + log_params: str = "gradients", + log_freq: int = 100, + log_graph: bool = False, + ) -> None: + self.log_params = log_params + self.log_freq = log_freq + self.log_graph = log_graph + + @rank_zero_only + def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Watches model weights with wandb.""" + logger = get_wandb_logger(trainer) + logger.watch( + model=trainer.model, + log=self.log_params, + log_freq=self.log_freq, + log_graph=self.log_graph, + ) + + +class UploadConfigAsArtifact(Callback): + """Upload all *.py files to W&B as an artifact, at the beginning of the run.""" + + def __init__(self) -> None: + self.config_dir = Path(".hydra/") + + @rank_zero_only + def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Uploads project code as an artifact.""" + logger = get_wandb_logger(trainer) + experiment = logger.experiment + artifact = wandb.Artifact("experiment-config", type="config") + for filepath in self.config_dir.rglob("*.yaml"): + artifact.add_file(str(filepath)) + + experiment.use_artifact(artifact) + + +class UploadCheckpointsAsArtifact(Callback): + """Upload checkpoint to wandb as an artifact, at the end of a run.""" + + def __init__( + self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False + ) -> None: + self.ckpt_dir = Path(ckpt_dir) + self.upload_best_only = upload_best_only + + @rank_zero_only + def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Uploads model checkpoint to W&B.""" + logger = get_wandb_logger(trainer) + experiment = logger.experiment + ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints") + + if self.upload_best_only: + ckpts.add_file(trainer.checkpoint_callback.best_model_path) + else: + for ckpt in (self.ckpt_dir).rglob("*.ckpt"): + ckpts.add_file(ckpt) + + experiment.use_artifact(ckpts) + + +class ImageToCaptionLogger(Callback): + """Logs the image and output caption.""" + + def __init__(self, num_samples: int = 8, on_train: bool = True) -> None: + self.num_samples = num_samples + self.on_train = on_train + self._required_keys = ("predictions", "ground_truths") + + def _log_captions( + self, trainer: Trainer, batch: Tuple[Tensor, Tensor], outputs: dict, key: str + ) -> None: + xs, _ = batch + preds, gts = outputs["predictions"], outputs["ground_truths"] + xs, preds, gts = ( + list(xs[: self.num_samples]), + preds[: self.num_samples], + gts[: self.num_samples], + ) + trainer.logger.log_image(key, xs, caption=preds) + + @rank_zero_only + def on_train_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + outputs: dict, + batch: Tuple[Tensor, Tensor], + batch_idx: int, + ) -> None: + """Logs predictions on validation batch end.""" + if self.has_metrics(outputs): + self._log_captions(trainer, batch, outputs, "train/predictions") + + @rank_zero_only + def on_validation_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + outputs: dict, + batch: Tuple[Tensor, Tensor], + batch_idx: int, + *args, + # dataloader_idx: int, + ) -> None: + """Logs predictions on validation batch end.""" + if self.has_metrics(outputs): + self._log_captions(trainer, batch, outputs, "val/predictions") + + @rank_zero_only + def on_test_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + outputs: dict, + batch: Tuple[Tensor, Tensor], + batch_idx: int, + dataloader_idx: int, + ) -> None: + """Logs predictions on train batch end.""" + if self.has_metrics(outputs): + self._log_captions(trainer, batch, outputs, "test/predictions") + + def has_metrics(self, outputs: dict) -> bool: + return all(k in outputs.keys() for k in self._required_keys) diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py deleted file mode 100644 index 1c7955c..0000000 --- a/training/callbacks/wandb_callbacks.py +++ /dev/null @@ -1,153 +0,0 @@ -"""Weights and Biases callbacks.""" -from pathlib import Path - -import wandb -from pytorch_lightning import Callback, LightningModule, Trainer -from pytorch_lightning.loggers import Logger, WandbLogger -from pytorch_lightning.utilities import rank_zero_only -from torch.utils.data import DataLoader - - -def get_wandb_logger(trainer: Trainer) -> WandbLogger: - """Safely get W&B logger from Trainer.""" - - for logger in trainer.loggers: - if isinstance(logger, WandbLogger): - return logger - - raise Exception("Weight and Biases logger not found for some reason...") - - -class WatchModel(Callback): - """Make W&B watch the model at the beginning of the run.""" - - def __init__( - self, - log_params: str = "gradients", - log_freq: int = 100, - log_graph: bool = False, - ) -> None: - self.log_params = log_params - self.log_freq = log_freq - self.log_graph = log_graph - - @rank_zero_only - def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Watches model weights with wandb.""" - logger = get_wandb_logger(trainer) - logger.watch( - model=trainer.model, - log=self.log_params, - log_freq=self.log_freq, - log_graph=self.log_graph, - ) - - -class UploadConfigAsArtifact(Callback): - """Upload all *.py files to W&B as an artifact, at the beginning of the run.""" - - def __init__(self) -> None: - self.config_dir = Path(".hydra/") - - @rank_zero_only - def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Uploads project code as an artifact.""" - logger = get_wandb_logger(trainer) - experiment = logger.experiment - artifact = wandb.Artifact("experiment-config", type="config") - for filepath in self.config_dir.rglob("*.yaml"): - artifact.add_file(str(filepath)) - - experiment.use_artifact(artifact) - - -class UploadCheckpointsAsArtifact(Callback): - """Upload checkpoint to wandb as an artifact, at the end of a run.""" - - def __init__( - self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False - ) -> None: - self.ckpt_dir = Path(ckpt_dir) - self.upload_best_only = upload_best_only - - @rank_zero_only - def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Uploads model checkpoint to W&B.""" - logger = get_wandb_logger(trainer) - experiment = logger.experiment - ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints") - - if self.upload_best_only: - ckpts.add_file(trainer.checkpoint_callback.best_model_path) - else: - for ckpt in (self.ckpt_dir).rglob("*.ckpt"): - ckpts.add_file(ckpt) - - experiment.use_artifact(ckpts) - - -class LogTextPredictions(Callback): - """Logs a validation batch with image to text transcription.""" - - def __init__(self, num_samples: int = 8) -> None: - self.num_samples = num_samples - self.ready = False - - def _log_predictions( - self, - stage: str, - trainer: Trainer, - pl_module: LightningModule, - dataloader: DataLoader, - ) -> None: - """Logs the predicted text contained in the images.""" - if not self.ready: - return None - - logger = get_wandb_logger(trainer) - experiment = logger.experiment - - # Get a validation batch from the validation dataloader. - samples = next(iter(dataloader)) - imgs, labels = samples - - imgs = imgs.to(device=pl_module.device) - logits = pl_module.predict(imgs) - - tokenizer = pl_module.tokenizer - data = [ - wandb.Image(img, caption=tokenizer.decode(pred)) - for img, pred, label in zip( - imgs[: self.num_samples], - logits[: self.num_samples], - labels[: self.num_samples], - ) - ] - - experiment.log({f"HTR/{experiment.name}/{stage}": data}) - - def on_sanity_check_start( - self, trainer: Trainer, pl_module: LightningModule - ) -> None: - """Sets ready attribute.""" - self.ready = False - - def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Start executing this callback only after all validation sanity checks end.""" - self.ready = True - - def on_validation_epoch_end( - self, trainer: Trainer, pl_module: LightningModule - ) -> None: - """Logs predictions on validation epoch end.""" - dataloader = trainer.datamodule.val_dataloader() - self._log_predictions( - stage="val", trainer=trainer, pl_module=pl_module, dataloader=dataloader - ) - - def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Logs predictions on train epoch end.""" - dataloader = trainer.datamodule.test_dataloader() - self._log_predictions( - stage="test", trainer=trainer, pl_module=pl_module, dataloader=dataloader - ) -- cgit v1.2.3-70-g09d2