summaryrefslogtreecommitdiff
path: root/training/callbacks
diff options
context:
space:
mode:
Diffstat (limited to 'training/callbacks')
-rw-r--r--training/callbacks/wandb.py (renamed from training/callbacks/wandb_callbacks.py)110
1 files changed, 55 insertions, 55 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb.py
index 1c7955c..d9bb9b8 100644
--- a/training/callbacks/wandb_callbacks.py
+++ b/training/callbacks/wandb.py
@@ -1,11 +1,12 @@
"""Weights and Biases callbacks."""
from pathlib import Path
+from typing import Tuple
import wandb
+from torch import Tensor
from pytorch_lightning import Callback, LightningModule, Trainer
-from pytorch_lightning.loggers import Logger, WandbLogger
+from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.utilities import rank_zero_only
-from torch.utils.data import DataLoader
def get_wandb_logger(trainer: Trainer) -> WandbLogger:
@@ -86,68 +87,67 @@ class UploadCheckpointsAsArtifact(Callback):
experiment.use_artifact(ckpts)
-class LogTextPredictions(Callback):
- """Logs a validation batch with image to text transcription."""
+class ImageToCaptionLogger(Callback):
+ """Logs the image and output caption."""
- def __init__(self, num_samples: int = 8) -> None:
+ def __init__(self, num_samples: int = 8, on_train: bool = True) -> None:
self.num_samples = num_samples
- self.ready = False
+ self.on_train = on_train
+ self._required_keys = ("predictions", "ground_truths")
- def _log_predictions(
+ def _log_captions(
+ self, trainer: Trainer, batch: Tuple[Tensor, Tensor], outputs: dict, key: str
+ ) -> None:
+ xs, _ = batch
+ preds, gts = outputs["predictions"], outputs["ground_truths"]
+ xs, preds, gts = (
+ list(xs[: self.num_samples]),
+ preds[: self.num_samples],
+ gts[: self.num_samples],
+ )
+ trainer.logger.log_image(key, xs, caption=preds)
+
+ @rank_zero_only
+ def on_train_batch_end(
self,
- stage: str,
trainer: Trainer,
pl_module: LightningModule,
- dataloader: DataLoader,
+ outputs: dict,
+ batch: Tuple[Tensor, Tensor],
+ batch_idx: int,
) -> None:
- """Logs the predicted text contained in the images."""
- if not self.ready:
- return None
-
- logger = get_wandb_logger(trainer)
- experiment = logger.experiment
-
- # Get a validation batch from the validation dataloader.
- samples = next(iter(dataloader))
- imgs, labels = samples
-
- imgs = imgs.to(device=pl_module.device)
- logits = pl_module.predict(imgs)
-
- tokenizer = pl_module.tokenizer
- data = [
- wandb.Image(img, caption=tokenizer.decode(pred))
- for img, pred, label in zip(
- imgs[: self.num_samples],
- logits[: self.num_samples],
- labels[: self.num_samples],
- )
- ]
-
- experiment.log({f"HTR/{experiment.name}/{stage}": data})
+ """Logs predictions on validation batch end."""
+ if self.has_metrics(outputs):
+ self._log_captions(trainer, batch, outputs, "train/predictions")
- def on_sanity_check_start(
- self, trainer: Trainer, pl_module: LightningModule
+ @rank_zero_only
+ def on_validation_batch_end(
+ self,
+ trainer: Trainer,
+ pl_module: LightningModule,
+ outputs: dict,
+ batch: Tuple[Tensor, Tensor],
+ batch_idx: int,
+ *args,
+ # dataloader_idx: int,
) -> None:
- """Sets ready attribute."""
- self.ready = False
-
- def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
- """Start executing this callback only after all validation sanity checks end."""
- self.ready = True
+ """Logs predictions on validation batch end."""
+ if self.has_metrics(outputs):
+ self._log_captions(trainer, batch, outputs, "val/predictions")
- def on_validation_epoch_end(
- self, trainer: Trainer, pl_module: LightningModule
+ @rank_zero_only
+ def on_test_batch_end(
+ self,
+ trainer: Trainer,
+ pl_module: LightningModule,
+ outputs: dict,
+ batch: Tuple[Tensor, Tensor],
+ batch_idx: int,
+ dataloader_idx: int,
) -> None:
- """Logs predictions on validation epoch end."""
- dataloader = trainer.datamodule.val_dataloader()
- self._log_predictions(
- stage="val", trainer=trainer, pl_module=pl_module, dataloader=dataloader
- )
+ """Logs predictions on train batch end."""
+ if self.has_metrics(outputs):
+ self._log_captions(trainer, batch, outputs, "test/predictions")
- def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
- """Logs predictions on train epoch end."""
- dataloader = trainer.datamodule.test_dataloader()
- self._log_predictions(
- stage="test", trainer=trainer, pl_module=pl_module, dataloader=dataloader
- )
+ def has_metrics(self, outputs: dict) -> bool:
+ return all(k in outputs.keys() for k in self._required_keys)