summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/callbacks/wandb_callbacks.py40
1 files changed, 33 insertions, 7 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py
index 68e4135..b23f720 100644
--- a/training/callbacks/wandb_callbacks.py
+++ b/training/callbacks/wandb_callbacks.py
@@ -5,6 +5,7 @@ import wandb
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.loggers import LoggerCollection, WandbLogger
from pytorch_lightning.utilities import rank_zero_only
+from torch import nn
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
@@ -83,8 +84,9 @@ class UploadCheckpointsAsArtifact(Callback):
class LogTextPredictions(Callback):
"""Logs a validation batch with image to text transcription."""
- def __init__(self, num_samples: int = 8) -> None:
+ def __init__(self, num_samples: int = 8, log_train: bool = False) -> None:
self.num_samples = num_samples
+ self.log_train = log_train
self.ready = False
def _log_predictions(
@@ -109,9 +111,8 @@ class LogTextPredictions(Callback):
logits = pl_module(imgs)
mapping = pl_module.mapping
- columns = ["image", "prediction", "truth"]
data = [
- [wandb.Image(img), mapping.get_text(pred), mapping.get_text(label)]
+ wandb.Image(img, caption=mapping.get_text(pred))
for img, pred, label in zip(
imgs[: self.num_samples],
logits[: self.num_samples],
@@ -119,9 +120,7 @@ class LogTextPredictions(Callback):
)
]
- experiment.log(
- {f"HTR/{experiment.name}/{stage}": wandb.Table(data=data, columns=columns)}
- )
+ experiment.log({f"HTR/{experiment.name}/{stage}": data})
def on_sanity_check_start(
self, trainer: Trainer, pl_module: LightningModule
@@ -133,6 +132,15 @@ class LogTextPredictions(Callback):
"""Start executing this callback only after all validation sanity checks end."""
self.ready = True
+ def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+ """Logs predictions on train epoch end."""
+ if not self.log_train:
+ return
+ dataloader = trainer.datamodule.train_dataloader()
+ self._log_predictions(
+ stage="train", trainer=trainer, pl_module=pl_module, dataloader=dataloader
+ )
+
def on_validation_epoch_end(
self, trainer: Trainer, pl_module: LightningModule
) -> None:
@@ -153,9 +161,13 @@ class LogTextPredictions(Callback):
class LogReconstuctedImages(Callback):
"""Log reconstructions of images."""
- def __init__(self, num_samples: int = 8) -> None:
+ def __init__(
+ self, num_samples: int = 8, log_train: bool = False, use_sigmoid: bool = False
+ ) -> None:
self.num_samples = num_samples
+ self.log_train = log_train
self.ready = False
+ self.sigmoid = nn.Sigmoid() if use_sigmoid else None
def _log_reconstruction(
self,
@@ -177,6 +189,11 @@ class LogReconstuctedImages(Callback):
imgs = imgs.to(device=pl_module.device)
reconstructions = pl_module(imgs)[0]
+ reconstructions = (
+ self.sigmoid(reconstructions)
+ if self.sigmoid is not None
+ else reconstructions
+ )
data = [
wandb.Image(
@@ -200,6 +217,15 @@ class LogReconstuctedImages(Callback):
"""Start executing this callback only after all validation sanity checks end."""
self.ready = True
+ def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+ """Logs predictions on train epoch end."""
+ if not self.log_train:
+ return
+ dataloader = trainer.datamodule.train_dataloader()
+ self._log_reconstruction(
+ stage="train", trainer=trainer, pl_module=pl_module, dataloader=dataloader
+ )
+
def on_validation_epoch_end(
self, trainer: Trainer, pl_module: LightningModule
) -> None: