From 9d8e73039b840bf3a2b52adcb7d279a2accd9790 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 10 Oct 2021 18:07:08 +0200 Subject: Update wandb callbacks --- training/callbacks/wandb_callbacks.py | 26 ++------------------------ 1 file changed, 2 insertions(+), 24 deletions(-) diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py index b23f720..c9d50d5 100644 --- a/training/callbacks/wandb_callbacks.py +++ b/training/callbacks/wandb_callbacks.py @@ -84,9 +84,8 @@ class UploadCheckpointsAsArtifact(Callback): class LogTextPredictions(Callback): """Logs a validation batch with image to text transcription.""" - def __init__(self, num_samples: int = 8, log_train: bool = False) -> None: + def __init__(self, num_samples: int = 8) -> None: self.num_samples = num_samples - self.log_train = log_train self.ready = False def _log_predictions( @@ -132,15 +131,6 @@ class LogTextPredictions(Callback): """Start executing this callback only after all validation sanity checks end.""" self.ready = True - def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Logs predictions on train epoch end.""" - if not self.log_train: - return - dataloader = trainer.datamodule.train_dataloader() - self._log_predictions( - stage="train", trainer=trainer, pl_module=pl_module, dataloader=dataloader - ) - def on_validation_epoch_end( self, trainer: Trainer, pl_module: LightningModule ) -> None: @@ -161,11 +151,8 @@ class LogTextPredictions(Callback): class LogReconstuctedImages(Callback): """Log reconstructions of images.""" - def __init__( - self, num_samples: int = 8, log_train: bool = False, use_sigmoid: bool = False - ) -> None: + def __init__(self, num_samples: int = 8, use_sigmoid: bool = False) -> None: self.num_samples = num_samples - self.log_train = log_train self.ready = False self.sigmoid = nn.Sigmoid() if use_sigmoid else None @@ -217,15 +204,6 @@ class LogReconstuctedImages(Callback): """Start executing this callback only after all validation sanity checks end.""" self.ready = True - def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Logs predictions on train epoch end.""" - if not self.log_train: - return - dataloader = trainer.datamodule.train_dataloader() - self._log_reconstruction( - stage="train", trainer=trainer, pl_module=pl_module, dataloader=dataloader - ) - def on_validation_epoch_end( self, trainer: Trainer, pl_module: LightningModule ) -> None: -- cgit v1.2.3-70-g09d2