diff options
Diffstat (limited to 'training/callbacks/wandb.py')
-rw-r--r-- | training/callbacks/wandb.py | 34 |
1 files changed, 3 insertions, 31 deletions
diff --git a/training/callbacks/wandb.py b/training/callbacks/wandb.py index d9bb9b8..6adbebe 100644 --- a/training/callbacks/wandb.py +++ b/training/callbacks/wandb.py @@ -62,32 +62,7 @@ class UploadConfigAsArtifact(Callback): experiment.use_artifact(artifact) -class UploadCheckpointsAsArtifact(Callback): - """Upload checkpoint to wandb as an artifact, at the end of a run.""" - - def __init__( - self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False - ) -> None: - self.ckpt_dir = Path(ckpt_dir) - self.upload_best_only = upload_best_only - - @rank_zero_only - def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Uploads model checkpoint to W&B.""" - logger = get_wandb_logger(trainer) - experiment = logger.experiment - ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints") - - if self.upload_best_only: - ckpts.add_file(trainer.checkpoint_callback.best_model_path) - else: - for ckpt in (self.ckpt_dir).rglob("*.ckpt"): - ckpts.add_file(ckpt) - - experiment.use_artifact(ckpts) - - -class ImageToCaptionLogger(Callback): +class ImageToCaption(Callback): """Logs the image and output caption.""" def __init__(self, num_samples: int = 8, on_train: bool = True) -> None: @@ -114,7 +89,7 @@ class ImageToCaptionLogger(Callback): pl_module: LightningModule, outputs: dict, batch: Tuple[Tensor, Tensor], - batch_idx: int, + *args, ) -> None: """Logs predictions on validation batch end.""" if self.has_metrics(outputs): @@ -127,9 +102,7 @@ class ImageToCaptionLogger(Callback): pl_module: LightningModule, outputs: dict, batch: Tuple[Tensor, Tensor], - batch_idx: int, *args, - # dataloader_idx: int, ) -> None: """Logs predictions on validation batch end.""" if self.has_metrics(outputs): @@ -142,8 +115,7 @@ class ImageToCaptionLogger(Callback): pl_module: LightningModule, outputs: dict, batch: Tuple[Tensor, Tensor], - batch_idx: int, - dataloader_idx: int, + *args, ) -> None: """Logs predictions on train batch end.""" if self.has_metrics(outputs): |