summaryrefslogtreecommitdiff
path: root/training/callbacks/wandb.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/callbacks/wandb.py')
-rw-r--r--training/callbacks/wandb.py34
1 files changed, 3 insertions, 31 deletions
diff --git a/training/callbacks/wandb.py b/training/callbacks/wandb.py
index d9bb9b8..6adbebe 100644
--- a/training/callbacks/wandb.py
+++ b/training/callbacks/wandb.py
@@ -62,32 +62,7 @@ class UploadConfigAsArtifact(Callback):
experiment.use_artifact(artifact)
-class UploadCheckpointsAsArtifact(Callback):
- """Upload checkpoint to wandb as an artifact, at the end of a run."""
-
- def __init__(
- self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False
- ) -> None:
- self.ckpt_dir = Path(ckpt_dir)
- self.upload_best_only = upload_best_only
-
- @rank_zero_only
- def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
- """Uploads model checkpoint to W&B."""
- logger = get_wandb_logger(trainer)
- experiment = logger.experiment
- ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints")
-
- if self.upload_best_only:
- ckpts.add_file(trainer.checkpoint_callback.best_model_path)
- else:
- for ckpt in (self.ckpt_dir).rglob("*.ckpt"):
- ckpts.add_file(ckpt)
-
- experiment.use_artifact(ckpts)
-
-
-class ImageToCaptionLogger(Callback):
+class ImageToCaption(Callback):
"""Logs the image and output caption."""
def __init__(self, num_samples: int = 8, on_train: bool = True) -> None:
@@ -114,7 +89,7 @@ class ImageToCaptionLogger(Callback):
pl_module: LightningModule,
outputs: dict,
batch: Tuple[Tensor, Tensor],
- batch_idx: int,
+ *args,
) -> None:
"""Logs predictions on validation batch end."""
if self.has_metrics(outputs):
@@ -127,9 +102,7 @@ class ImageToCaptionLogger(Callback):
pl_module: LightningModule,
outputs: dict,
batch: Tuple[Tensor, Tensor],
- batch_idx: int,
*args,
- # dataloader_idx: int,
) -> None:
"""Logs predictions on validation batch end."""
if self.has_metrics(outputs):
@@ -142,8 +115,7 @@ class ImageToCaptionLogger(Callback):
pl_module: LightningModule,
outputs: dict,
batch: Tuple[Tensor, Tensor],
- batch_idx: int,
- dataloader_idx: int,
+ *args,
) -> None:
"""Logs predictions on train batch end."""
if self.has_metrics(outputs):