From 97092991618d8c7ddfc477f2ceba0749071587d0 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Thu, 8 Jul 2021 22:26:30 +0200 Subject: Move callbacks to training folder, refactor --- text_recognizer/callbacks/__init__.py | 1 - text_recognizer/callbacks/wandb_callbacks.py | 211 --------------------------- 2 files changed, 212 deletions(-) delete mode 100644 text_recognizer/callbacks/__init__.py delete mode 100644 text_recognizer/callbacks/wandb_callbacks.py diff --git a/text_recognizer/callbacks/__init__.py b/text_recognizer/callbacks/__init__.py deleted file mode 100644 index 82d8ce3..0000000 --- a/text_recognizer/callbacks/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Module for PyTorch Lightning callbacks.""" diff --git a/text_recognizer/callbacks/wandb_callbacks.py b/text_recognizer/callbacks/wandb_callbacks.py deleted file mode 100644 index d9d81f6..0000000 --- a/text_recognizer/callbacks/wandb_callbacks.py +++ /dev/null @@ -1,211 +0,0 @@ -"""Weights and Biases callbacks.""" -from pathlib import Path -from typing import List - -import attr -import wandb -from pytorch_lightning import Callback, LightningModule, Trainer -from pytorch_lightning.loggers import LoggerCollection, WandbLogger - - -def get_wandb_logger(trainer: Trainer) -> WandbLogger: - """Safely get W&B logger from Trainer.""" - - if isinstance(trainer.logger, WandbLogger): - return trainer.logger - - if isinstance(trainer.logger, LoggerCollection): - for logger in trainer.logger: - if isinstance(logger, WandbLogger): - return logger - - raise Exception("Weight and Biases logger not found for some reason...") - - -@attr.s -class WatchModel(Callback): - """Make W&B watch the model at the beginning of the run.""" - - log: str = attr.ib(default="gradients") - log_freq: int = attr.ib(default=100) - - def __attrs_pre_init__(self) -> None: - super().__init__() - - def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Watches model weights with wandb.""" - logger = get_wandb_logger(trainer) - logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq) - - -@attr.s -class UploadCodeAsArtifact(Callback): - """Upload all *.py files to W&B as an artifact, at the beginning of the run.""" - - project_dir: Path = attr.ib(converter=Path) - - def __attrs_pre_init__(self) -> None: - super().__init__() - - def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Uploads project code as an artifact.""" - logger = get_wandb_logger(trainer) - experiment = logger.experiment - artifact = wandb.Artifact("project-source", type="code") - for filepath in self.project_dir.glob("**/*.py"): - artifact.add_file(filepath) - - experiment.use_artifact(artifact) - - -@attr.s -class UploadCheckpointAsArtifact(Callback): - """Upload checkpoint to wandb as an artifact, at the end of a run.""" - - ckpt_dir: Path = attr.ib(converter=Path) - upload_best_only: bool = attr.ib() - - def __attrs_pre_init__(self) -> None: - super().__init__() - - def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Uploads model checkpoint to W&B.""" - logger = get_wandb_logger(trainer) - experiment = logger.experiment - ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints") - - if self.upload_best_only: - ckpts.add_file(trainer.checkpoint_callback.best_model_path) - else: - for ckpt in (self.ckpt_dir).glob("**/*.ckpt"): - ckpts.add_file(ckpt) - - experiment.use_artifact(ckpts) - - -@attr.s -class LogTextPredictions(Callback): - """Logs a validation batch with image to text transcription.""" - - num_samples: int = attr.ib(default=8) - ready: bool = attr.ib(default=True) - - def __attrs_pre_init__(self) -> None: - super().__init__() - - def _log_predictions( - stage: str, trainer: Trainer, pl_module: LightningModule - ) -> None: - """Logs the predicted text contained in the images.""" - if not self.ready: - return None - - logger = get_wandb_logger(trainer) - experiment = logger.experiment - - # Get a validation batch from the validation dataloader. - samples = next(iter(trainer.datamodule.val_dataloader())) - imgs, labels = samples - - imgs = imgs.to(device=pl_module.device) - logits = pl_module(imgs) - - mapping = pl_module.mapping - experiment.log( - { - f"OCR/{experiment.name}/{stage}": [ - wandb.Image( - img, - caption=f"Pred: {mapping.get_text(pred)}, Label: {mapping.get_text(label)}", - ) - for img, pred, label in zip( - imgs[: self.num_samples], - logits[: self.num_samples], - labels[: self.num_samples], - ) - ] - } - ) - - def on_sanity_check_start( - self, trainer: Trainer, pl_module: LightningModule - ) -> None: - """Sets ready attribute.""" - self.ready = False - - def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Start executing this callback only after all validation sanity checks end.""" - self.ready = True - - def on_validation_epoch_end( - self, trainer: Trainer, pl_module: LightningModule - ) -> None: - """Logs predictions on validation epoch end.""" - self._log_predictions(stage="val", trainer=trainer, pl_module=pl_module) - - def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Logs predictions on train epoch end.""" - self._log_predictions(stage="test", trainer=trainer, pl_module=pl_module) - - -@attr.s -class LogReconstuctedImages(Callback): - """Log reconstructions of images.""" - - num_samples: int = attr.ib(default=8) - ready: bool = attr.ib(default=True) - - def __attrs_pre_init__(self) -> None: - super().__init__() - - def _log_reconstruction( - self, stage: str, trainer: Trainer, pl_module: LightningModule - ) -> None: - """Logs the reconstructions.""" - if not self.ready: - return None - - logger = get_wandb_logger(trainer) - experiment = logger.experiment - - # Get a validation batch from the validation dataloader. - samples = next(iter(trainer.datamodule.val_dataloader())) - imgs, _ = samples - - imgs = imgs.to(device=pl_module.device) - reconstructions = pl_module(imgs) - - experiment.log( - { - f"Reconstructions/{experiment.name}/{stage}": [ - [ - wandb.Image(img), - wandb.Image(rec), - ] - for img, rec in zip( - imgs[: self.num_samples], - reconstructions[: self.num_samples], - ) - ] - } - ) - - def on_sanity_check_start( - self, trainer: Trainer, pl_module: LightningModule - ) -> None: - """Sets ready attribute.""" - self.ready = False - - def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Start executing this callback only after all validation sanity checks end.""" - self.ready = True - - def on_validation_epoch_end( - self, trainer: Trainer, pl_module: LightningModule - ) -> None: - """Logs predictions on validation epoch end.""" - self._log_reconstruction(stage="val", trainer=trainer, pl_module=pl_module) - - def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - """Logs predictions on train epoch end.""" - self._log_reconstruction(stage="test", trainer=trainer, pl_module=pl_module) -- cgit v1.2.3-70-g09d2