diff options
Diffstat (limited to 'text_recognizer/callbacks')
-rw-r--r-- | text_recognizer/callbacks/wandb_callbacks.py | 95 |
1 files changed, 84 insertions, 11 deletions
diff --git a/text_recognizer/callbacks/wandb_callbacks.py b/text_recognizer/callbacks/wandb_callbacks.py index 4186b4a..d9d81f6 100644 --- a/text_recognizer/callbacks/wandb_callbacks.py +++ b/text_recognizer/callbacks/wandb_callbacks.py @@ -93,6 +93,40 @@ class LogTextPredictions(Callback): def __attrs_pre_init__(self) -> None: super().__init__() + def _log_predictions( + stage: str, trainer: Trainer, pl_module: LightningModule + ) -> None: + """Logs the predicted text contained in the images.""" + if not self.ready: + return None + + logger = get_wandb_logger(trainer) + experiment = logger.experiment + + # Get a validation batch from the validation dataloader. + samples = next(iter(trainer.datamodule.val_dataloader())) + imgs, labels = samples + + imgs = imgs.to(device=pl_module.device) + logits = pl_module(imgs) + + mapping = pl_module.mapping + experiment.log( + { + f"OCR/{experiment.name}/{stage}": [ + wandb.Image( + img, + caption=f"Pred: {mapping.get_text(pred)}, Label: {mapping.get_text(label)}", + ) + for img, pred, label in zip( + imgs[: self.num_samples], + logits[: self.num_samples], + labels[: self.num_samples], + ) + ] + } + ) + def on_sanity_check_start( self, trainer: Trainer, pl_module: LightningModule ) -> None: @@ -107,6 +141,27 @@ class LogTextPredictions(Callback): self, trainer: Trainer, pl_module: LightningModule ) -> None: """Logs predictions on validation epoch end.""" + self._log_predictions(stage="val", trainer=trainer, pl_module=pl_module) + + def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Logs predictions on train epoch end.""" + self._log_predictions(stage="test", trainer=trainer, pl_module=pl_module) + + +@attr.s +class LogReconstuctedImages(Callback): + """Log reconstructions of images.""" + + num_samples: int = attr.ib(default=8) + ready: bool = attr.ib(default=True) + + def __attrs_pre_init__(self) -> None: + super().__init__() + + def _log_reconstruction( + self, stage: str, trainer: Trainer, pl_module: LightningModule + ) -> None: + """Logs the reconstructions.""" if not self.ready: return None @@ -115,24 +170,42 @@ class LogTextPredictions(Callback): # Get a validation batch from the validation dataloader. samples = next(iter(trainer.datamodule.val_dataloader())) - imgs, labels = samples + imgs, _ = samples imgs = imgs.to(device=pl_module.device) - logits = pl_module(imgs) + reconstructions = pl_module(imgs) - mapping = pl_module.mapping experiment.log( { - f"Images/{experiment.name}": [ - wandb.Image( - img, - caption=f"Pred: {mapping.get_text(pred)}, Label: {mapping.get_text(label)}", - ) - for img, pred, label in zip( + f"Reconstructions/{experiment.name}/{stage}": [ + [ + wandb.Image(img), + wandb.Image(rec), + ] + for img, rec in zip( imgs[: self.num_samples], - logits[: self.num_samples], - labels[: self.num_samples], + reconstructions[: self.num_samples], ) ] } ) + + def on_sanity_check_start( + self, trainer: Trainer, pl_module: LightningModule + ) -> None: + """Sets ready attribute.""" + self.ready = False + + def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Start executing this callback only after all validation sanity checks end.""" + self.ready = True + + def on_validation_epoch_end( + self, trainer: Trainer, pl_module: LightningModule + ) -> None: + """Logs predictions on validation epoch end.""" + self._log_reconstruction(stage="val", trainer=trainer, pl_module=pl_module) + + def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Logs predictions on train epoch end.""" + self._log_reconstruction(stage="test", trainer=trainer, pl_module=pl_module) |