diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-10 18:07:08 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-10 18:07:08 +0200 |
commit | 9d8e73039b840bf3a2b52adcb7d279a2accd9790 (patch) | |
tree | 5a1aff1f0646ae55e56ae8d264b4b96eac3c27a9 /training/callbacks | |
parent | b69254ce3135c112e29f7f1c986b7f0817da0c33 (diff) |
Update wandb callbacks
Diffstat (limited to 'training/callbacks')
-rw-r--r-- | training/callbacks/wandb_callbacks.py | 26 |
1 files changed, 2 insertions, 24 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py index b23f720..c9d50d5 100644 --- a/training/callbacks/wandb_callbacks.py +++ b/training/callbacks/wandb_callbacks.py @@ -84,9 +84,8 @@ class UploadCheckpointsAsArtifact(Callback): class LogTextPredictions(Callback): """Logs a validation batch with image to text transcription.""" - def __init__(self, num_samples: int = 8, log_train: bool = False) -> None: + def __init__(self, num_samples: int = 8) -> None: self.num_samples = num_samples - self.log_train = log_train self.ready = False def _log_predictions( @@ -132,15 +131,6 @@ class LogTextPredictions(Callback): """Start executing this callback only after all validation sanity checks end.""" self.ready = True - def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Logs predictions on train epoch end.""" - if not self.log_train: - return - dataloader = trainer.datamodule.train_dataloader() - self._log_predictions( - stage="train", trainer=trainer, pl_module=pl_module, dataloader=dataloader - ) - def on_validation_epoch_end( self, trainer: Trainer, pl_module: LightningModule ) -> None: @@ -161,11 +151,8 @@ class LogTextPredictions(Callback): class LogReconstuctedImages(Callback): """Log reconstructions of images.""" - def __init__( - self, num_samples: int = 8, log_train: bool = False, use_sigmoid: bool = False - ) -> None: + def __init__(self, num_samples: int = 8, use_sigmoid: bool = False) -> None: self.num_samples = num_samples - self.log_train = log_train self.ready = False self.sigmoid = nn.Sigmoid() if use_sigmoid else None @@ -217,15 +204,6 @@ class LogReconstuctedImages(Callback): """Start executing this callback only after all validation sanity checks end.""" self.ready = True - def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Logs predictions on train epoch end.""" - if not self.log_train: - return - dataloader = trainer.datamodule.train_dataloader() - self._log_reconstruction( - stage="train", trainer=trainer, pl_module=pl_module, dataloader=dataloader - ) - def on_validation_epoch_end( self, trainer: Trainer, pl_module: LightningModule ) -> None: |