diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/callbacks/wandb_callbacks.py | 26 |
1 files changed, 2 insertions, 24 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py index b23f720..c9d50d5 100644 --- a/training/callbacks/wandb_callbacks.py +++ b/training/callbacks/wandb_callbacks.py @@ -84,9 +84,8 @@ class UploadCheckpointsAsArtifact(Callback): class LogTextPredictions(Callback): """Logs a validation batch with image to text transcription.""" - def __init__(self, num_samples: int = 8, log_train: bool = False) -> None: + def __init__(self, num_samples: int = 8) -> None: self.num_samples = num_samples - self.log_train = log_train self.ready = False def _log_predictions( @@ -132,15 +131,6 @@ class LogTextPredictions(Callback): """Start executing this callback only after all validation sanity checks end.""" self.ready = True - def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Logs predictions on train epoch end.""" - if not self.log_train: - return - dataloader = trainer.datamodule.train_dataloader() - self._log_predictions( - stage="train", trainer=trainer, pl_module=pl_module, dataloader=dataloader - ) - def on_validation_epoch_end( self, trainer: Trainer, pl_module: LightningModule ) -> None: @@ -161,11 +151,8 @@ class LogTextPredictions(Callback): class LogReconstuctedImages(Callback): """Log reconstructions of images.""" - def __init__( - self, num_samples: int = 8, log_train: bool = False, use_sigmoid: bool = False - ) -> None: + def __init__(self, num_samples: int = 8, use_sigmoid: bool = False) -> None: self.num_samples = num_samples - self.log_train = log_train self.ready = False self.sigmoid = nn.Sigmoid() if use_sigmoid else None @@ -217,15 +204,6 @@ class LogReconstuctedImages(Callback): """Start executing this callback only after all validation sanity checks end.""" self.ready = True - def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Logs predictions on train epoch end.""" - if not self.log_train: - return - dataloader = trainer.datamodule.train_dataloader() - self._log_reconstruction( - stage="train", trainer=trainer, pl_module=pl_module, dataloader=dataloader - ) - def on_validation_epoch_end( self, trainer: Trainer, pl_module: LightningModule ) -> None: |