summaryrefslogtreecommitdiff
path: root/text_recognizer/callbacks
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/callbacks')
-rw-r--r--text_recognizer/callbacks/wandb_callbacks.py95
1 files changed, 84 insertions, 11 deletions
diff --git a/text_recognizer/callbacks/wandb_callbacks.py b/text_recognizer/callbacks/wandb_callbacks.py
index 4186b4a..d9d81f6 100644
--- a/text_recognizer/callbacks/wandb_callbacks.py
+++ b/text_recognizer/callbacks/wandb_callbacks.py
@@ -93,6 +93,40 @@ class LogTextPredictions(Callback):
def __attrs_pre_init__(self) -> None:
super().__init__()
+ def _log_predictions(
+ stage: str, trainer: Trainer, pl_module: LightningModule
+ ) -> None:
+ """Logs the predicted text contained in the images."""
+ if not self.ready:
+ return None
+
+ logger = get_wandb_logger(trainer)
+ experiment = logger.experiment
+
+ # Get a validation batch from the validation dataloader.
+ samples = next(iter(trainer.datamodule.val_dataloader()))
+ imgs, labels = samples
+
+ imgs = imgs.to(device=pl_module.device)
+ logits = pl_module(imgs)
+
+ mapping = pl_module.mapping
+ experiment.log(
+ {
+ f"OCR/{experiment.name}/{stage}": [
+ wandb.Image(
+ img,
+ caption=f"Pred: {mapping.get_text(pred)}, Label: {mapping.get_text(label)}",
+ )
+ for img, pred, label in zip(
+ imgs[: self.num_samples],
+ logits[: self.num_samples],
+ labels[: self.num_samples],
+ )
+ ]
+ }
+ )
+
def on_sanity_check_start(
self, trainer: Trainer, pl_module: LightningModule
) -> None:
@@ -107,6 +141,27 @@ class LogTextPredictions(Callback):
self, trainer: Trainer, pl_module: LightningModule
) -> None:
"""Logs predictions on validation epoch end."""
+ self._log_predictions(stage="val", trainer=trainer, pl_module=pl_module)
+
+ def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+ """Logs predictions on train epoch end."""
+ self._log_predictions(stage="test", trainer=trainer, pl_module=pl_module)
+
+
+@attr.s
+class LogReconstuctedImages(Callback):
+ """Log reconstructions of images."""
+
+ num_samples: int = attr.ib(default=8)
+ ready: bool = attr.ib(default=True)
+
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
+
+ def _log_reconstruction(
+ self, stage: str, trainer: Trainer, pl_module: LightningModule
+ ) -> None:
+ """Logs the reconstructions."""
if not self.ready:
return None
@@ -115,24 +170,42 @@ class LogTextPredictions(Callback):
# Get a validation batch from the validation dataloader.
samples = next(iter(trainer.datamodule.val_dataloader()))
- imgs, labels = samples
+ imgs, _ = samples
imgs = imgs.to(device=pl_module.device)
- logits = pl_module(imgs)
+ reconstructions = pl_module(imgs)
- mapping = pl_module.mapping
experiment.log(
{
- f"Images/{experiment.name}": [
- wandb.Image(
- img,
- caption=f"Pred: {mapping.get_text(pred)}, Label: {mapping.get_text(label)}",
- )
- for img, pred, label in zip(
+ f"Reconstructions/{experiment.name}/{stage}": [
+ [
+ wandb.Image(img),
+ wandb.Image(rec),
+ ]
+ for img, rec in zip(
imgs[: self.num_samples],
- logits[: self.num_samples],
- labels[: self.num_samples],
+ reconstructions[: self.num_samples],
)
]
}
)
+
+ def on_sanity_check_start(
+ self, trainer: Trainer, pl_module: LightningModule
+ ) -> None:
+ """Sets ready attribute."""
+ self.ready = False
+
+ def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+ """Start executing this callback only after all validation sanity checks end."""
+ self.ready = True
+
+ def on_validation_epoch_end(
+ self, trainer: Trainer, pl_module: LightningModule
+ ) -> None:
+ """Logs predictions on validation epoch end."""
+ self._log_reconstruction(stage="val", trainer=trainer, pl_module=pl_module)
+
+ def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+ """Logs predictions on train epoch end."""
+ self._log_reconstruction(stage="test", trainer=trainer, pl_module=pl_module)