summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/callbacks/wandb_callbacks.py26
1 files changed, 2 insertions, 24 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py
index b23f720..c9d50d5 100644
--- a/training/callbacks/wandb_callbacks.py
+++ b/training/callbacks/wandb_callbacks.py
@@ -84,9 +84,8 @@ class UploadCheckpointsAsArtifact(Callback):
class LogTextPredictions(Callback):
"""Logs a validation batch with image to text transcription."""
- def __init__(self, num_samples: int = 8, log_train: bool = False) -> None:
+ def __init__(self, num_samples: int = 8) -> None:
self.num_samples = num_samples
- self.log_train = log_train
self.ready = False
def _log_predictions(
@@ -132,15 +131,6 @@ class LogTextPredictions(Callback):
"""Start executing this callback only after all validation sanity checks end."""
self.ready = True
- def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
- """Logs predictions on train epoch end."""
- if not self.log_train:
- return
- dataloader = trainer.datamodule.train_dataloader()
- self._log_predictions(
- stage="train", trainer=trainer, pl_module=pl_module, dataloader=dataloader
- )
-
def on_validation_epoch_end(
self, trainer: Trainer, pl_module: LightningModule
) -> None:
@@ -161,11 +151,8 @@ class LogTextPredictions(Callback):
class LogReconstuctedImages(Callback):
"""Log reconstructions of images."""
- def __init__(
- self, num_samples: int = 8, log_train: bool = False, use_sigmoid: bool = False
- ) -> None:
+ def __init__(self, num_samples: int = 8, use_sigmoid: bool = False) -> None:
self.num_samples = num_samples
- self.log_train = log_train
self.ready = False
self.sigmoid = nn.Sigmoid() if use_sigmoid else None
@@ -217,15 +204,6 @@ class LogReconstuctedImages(Callback):
"""Start executing this callback only after all validation sanity checks end."""
self.ready = True
- def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
- """Logs predictions on train epoch end."""
- if not self.log_train:
- return
- dataloader = trainer.datamodule.train_dataloader()
- self._log_reconstruction(
- stage="train", trainer=trainer, pl_module=pl_module, dataloader=dataloader
- )
-
def on_validation_epoch_end(
self, trainer: Trainer, pl_module: LightningModule
) -> None: