diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-09-30 23:07:54 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-09-30 23:07:54 +0200 |
commit | af1fccaa154a4455ea6b5d6c5b1da1e1e427edba (patch) | |
tree | 4bbfc3ffdc790406700f9d0cc2cd6b89227dabe2 | |
parent | 52a4291e47ca23c9c7a43541f03280ec92aafde3 (diff) |
Replace table with image and caption for w&b callback
-rw-r--r-- | training/callbacks/wandb_callbacks.py | 40 |
1 files changed, 33 insertions, 7 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py index 68e4135..b23f720 100644 --- a/training/callbacks/wandb_callbacks.py +++ b/training/callbacks/wandb_callbacks.py @@ -5,6 +5,7 @@ import wandb from pytorch_lightning import Callback, LightningModule, Trainer from pytorch_lightning.loggers import LoggerCollection, WandbLogger from pytorch_lightning.utilities import rank_zero_only +from torch import nn from torch.utils.data import DataLoader from torchvision.utils import make_grid @@ -83,8 +84,9 @@ class UploadCheckpointsAsArtifact(Callback): class LogTextPredictions(Callback): """Logs a validation batch with image to text transcription.""" - def __init__(self, num_samples: int = 8) -> None: + def __init__(self, num_samples: int = 8, log_train: bool = False) -> None: self.num_samples = num_samples + self.log_train = log_train self.ready = False def _log_predictions( @@ -109,9 +111,8 @@ class LogTextPredictions(Callback): logits = pl_module(imgs) mapping = pl_module.mapping - columns = ["image", "prediction", "truth"] data = [ - [wandb.Image(img), mapping.get_text(pred), mapping.get_text(label)] + wandb.Image(img, caption=mapping.get_text(pred)) for img, pred, label in zip( imgs[: self.num_samples], logits[: self.num_samples], @@ -119,9 +120,7 @@ class LogTextPredictions(Callback): ) ] - experiment.log( - {f"HTR/{experiment.name}/{stage}": wandb.Table(data=data, columns=columns)} - ) + experiment.log({f"HTR/{experiment.name}/{stage}": data}) def on_sanity_check_start( self, trainer: Trainer, pl_module: LightningModule @@ -133,6 +132,15 @@ class LogTextPredictions(Callback): """Start executing this callback only after all validation sanity checks end.""" self.ready = True + def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Logs predictions on train epoch end.""" + if not self.log_train: + return + dataloader = trainer.datamodule.train_dataloader() + self._log_predictions( + stage="train", trainer=trainer, pl_module=pl_module, dataloader=dataloader + ) + def on_validation_epoch_end( self, trainer: Trainer, pl_module: LightningModule ) -> None: @@ -153,9 +161,13 @@ class LogTextPredictions(Callback): class LogReconstuctedImages(Callback): """Log reconstructions of images.""" - def __init__(self, num_samples: int = 8) -> None: + def __init__( + self, num_samples: int = 8, log_train: bool = False, use_sigmoid: bool = False + ) -> None: self.num_samples = num_samples + self.log_train = log_train self.ready = False + self.sigmoid = nn.Sigmoid() if use_sigmoid else None def _log_reconstruction( self, @@ -177,6 +189,11 @@ class LogReconstuctedImages(Callback): imgs = imgs.to(device=pl_module.device) reconstructions = pl_module(imgs)[0] + reconstructions = ( + self.sigmoid(reconstructions) + if self.sigmoid is not None + else reconstructions + ) data = [ wandb.Image( @@ -200,6 +217,15 @@ class LogReconstuctedImages(Callback): """Start executing this callback only after all validation sanity checks end.""" self.ready = True + def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Logs predictions on train epoch end.""" + if not self.log_train: + return + dataloader = trainer.datamodule.train_dataloader() + self._log_reconstruction( + stage="train", trainer=trainer, pl_module=pl_module, dataloader=dataloader + ) + def on_validation_epoch_end( self, trainer: Trainer, pl_module: LightningModule ) -> None: |