diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2023-09-02 01:53:37 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2023-09-02 01:53:37 +0200 |
commit | abc2d60d69d115cdb34615d8bcb6c03ab6357141 (patch) | |
tree | 74f8536f7ca072f917fd924d2528ccaf0c273b49 /training/callbacks | |
parent | 617bf7f0285090b85817a398ef4bb871d4f616e9 (diff) |
Refactor wandb callbacks
Diffstat (limited to 'training/callbacks')
-rw-r--r-- | training/callbacks/wandb.py (renamed from training/callbacks/wandb_callbacks.py) | 110 |
1 files changed, 55 insertions, 55 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb.py index 1c7955c..d9bb9b8 100644 --- a/training/callbacks/wandb_callbacks.py +++ b/training/callbacks/wandb.py @@ -1,11 +1,12 @@ """Weights and Biases callbacks.""" from pathlib import Path +from typing import Tuple import wandb +from torch import Tensor from pytorch_lightning import Callback, LightningModule, Trainer -from pytorch_lightning.loggers import Logger, WandbLogger +from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.utilities import rank_zero_only -from torch.utils.data import DataLoader def get_wandb_logger(trainer: Trainer) -> WandbLogger: @@ -86,68 +87,67 @@ class UploadCheckpointsAsArtifact(Callback): experiment.use_artifact(ckpts) -class LogTextPredictions(Callback): - """Logs a validation batch with image to text transcription.""" +class ImageToCaptionLogger(Callback): + """Logs the image and output caption.""" - def __init__(self, num_samples: int = 8) -> None: + def __init__(self, num_samples: int = 8, on_train: bool = True) -> None: self.num_samples = num_samples - self.ready = False + self.on_train = on_train + self._required_keys = ("predictions", "ground_truths") - def _log_predictions( + def _log_captions( + self, trainer: Trainer, batch: Tuple[Tensor, Tensor], outputs: dict, key: str + ) -> None: + xs, _ = batch + preds, gts = outputs["predictions"], outputs["ground_truths"] + xs, preds, gts = ( + list(xs[: self.num_samples]), + preds[: self.num_samples], + gts[: self.num_samples], + ) + trainer.logger.log_image(key, xs, caption=preds) + + @rank_zero_only + def on_train_batch_end( self, - stage: str, trainer: Trainer, pl_module: LightningModule, - dataloader: DataLoader, + outputs: dict, + batch: Tuple[Tensor, Tensor], + batch_idx: int, ) -> None: - """Logs the predicted text contained in the images.""" - if not self.ready: - return None - - logger = get_wandb_logger(trainer) - experiment = logger.experiment - - # Get a validation batch from the validation dataloader. - samples = next(iter(dataloader)) - imgs, labels = samples - - imgs = imgs.to(device=pl_module.device) - logits = pl_module.predict(imgs) - - tokenizer = pl_module.tokenizer - data = [ - wandb.Image(img, caption=tokenizer.decode(pred)) - for img, pred, label in zip( - imgs[: self.num_samples], - logits[: self.num_samples], - labels[: self.num_samples], - ) - ] - - experiment.log({f"HTR/{experiment.name}/{stage}": data}) + """Logs predictions on validation batch end.""" + if self.has_metrics(outputs): + self._log_captions(trainer, batch, outputs, "train/predictions") - def on_sanity_check_start( - self, trainer: Trainer, pl_module: LightningModule + @rank_zero_only + def on_validation_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + outputs: dict, + batch: Tuple[Tensor, Tensor], + batch_idx: int, + *args, + # dataloader_idx: int, ) -> None: - """Sets ready attribute.""" - self.ready = False - - def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Start executing this callback only after all validation sanity checks end.""" - self.ready = True + """Logs predictions on validation batch end.""" + if self.has_metrics(outputs): + self._log_captions(trainer, batch, outputs, "val/predictions") - def on_validation_epoch_end( - self, trainer: Trainer, pl_module: LightningModule + @rank_zero_only + def on_test_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + outputs: dict, + batch: Tuple[Tensor, Tensor], + batch_idx: int, + dataloader_idx: int, ) -> None: - """Logs predictions on validation epoch end.""" - dataloader = trainer.datamodule.val_dataloader() - self._log_predictions( - stage="val", trainer=trainer, pl_module=pl_module, dataloader=dataloader - ) + """Logs predictions on train batch end.""" + if self.has_metrics(outputs): + self._log_captions(trainer, batch, outputs, "test/predictions") - def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Logs predictions on train epoch end.""" - dataloader = trainer.datamodule.test_dataloader() - self._log_predictions( - stage="test", trainer=trainer, pl_module=pl_module, dataloader=dataloader - ) + def has_metrics(self, outputs: dict) -> bool: + return all(k in outputs.keys() for k in self._required_keys) |