From 969e1d5e179d9c42ffae0c9b12c9bd3be6091360 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Sun, 4 Jul 2021 22:58:07 +0200
Subject: Add wandb callbacks

---
 text_recognizer/callbacks/__init__.py              |   0
 text_recognizer/callbacks/wandb_callbacks.py       | 140 +++++++++++++++++++++
 training/conf/callbacks/checkpoint.yaml            |  15 ++-
 training/conf/callbacks/default.yaml               |   3 +
 training/conf/callbacks/early_stopping.yaml        |  10 +-
 training/conf/callbacks/learning_rate_monitor.yaml |   6 +-
 training/conf/callbacks/swa.yaml                   |  13 +--
 training/conf/callbacks/wandb.yaml                 |  20 ++++
 8 files changed, 186 insertions(+), 21 deletions(-)
 create mode 100644 text_recognizer/callbacks/__init__.py
 create mode 100644 text_recognizer/callbacks/wandb_callbacks.py
 create mode 100644 training/conf/callbacks/default.yaml
 create mode 100644 training/conf/callbacks/wandb.yaml

diff --git a/text_recognizer/callbacks/__init__.py b/text_recognizer/callbacks/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/text_recognizer/callbacks/wandb_callbacks.py b/text_recognizer/callbacks/wandb_callbacks.py
new file mode 100644
index 0000000..3936aaf
--- /dev/null
+++ b/text_recognizer/callbacks/wandb_callbacks.py
@@ -0,0 +1,140 @@
+"""Weights and Biases callbacks."""
+from pathlib import Path
+from typing import List
+
+import attr
+import wandb
+from pytorch_lightning import Callback, LightningModule, Trainer
+from pytorch_lightning.loggers import LoggerCollection, WandbLogger
+
+
+def get_wandb_logger(trainer: Trainer) -> WandbLogger:
+    """Return the W&B logger attached to the trainer.
+
+    The trainer's logger may be a WandbLogger itself or a LoggerCollection
+    that contains one.
+
+    Raises:
+        Exception: If no WandbLogger is found on the trainer.
+    """
+    if isinstance(trainer.logger, WandbLogger):
+        return trainer.logger
+
+    if isinstance(trainer.logger, LoggerCollection):
+        for logger in trainer.logger:
+            if isinstance(logger, WandbLogger):
+                return logger
+
+    raise Exception("Weight and Biases logger not found for some reason...")
+
+
+@attr.s
+class WatchModel(Callback):
+    """Make W&B watch the model at the beginning of the run."""
+
+    log: str = attr.ib(default="gradients")
+    log_freq: int = attr.ib(default=100)
+
+    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+        """Watches model weights with wandb."""
+        logger = get_wandb_logger(trainer)
+        logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq)
+
+
+@attr.s
+class UploadCodeAsArtifact(Callback):
+    """Upload all *.py files to W&B as an artifact, at the beginning of the run."""
+
+    project_dir: Path = attr.ib(converter=Path)
+
+    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+        """Uploads project code as an artifact."""
+        logger = get_wandb_logger(trainer)
+        experiment = logger.experiment
+        artifact = wandb.Artifact("project-source", type="code")
+        for filepath in self.project_dir.glob("**/*.py"):
+            # Artifact.add_file is documented to take a string path.
+            artifact.add_file(str(filepath))
+
+        experiment.use_artifact(artifact)
+
+
+@attr.s
+class UploadCheckpointAsArtifact(Callback):
+    """Upload checkpoint to wandb as an artifact, at the end of a run."""
+
+    ckpt_dir: Path = attr.ib(converter=Path)
+    upload_best_only: bool = attr.ib()
+
+    def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+        """Uploads model checkpoint to W&B."""
+        logger = get_wandb_logger(trainer)
+        experiment = logger.experiment
+        ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints")
+
+        if self.upload_best_only:
+            ckpts.add_file(trainer.checkpoint_callback.best_model_path)
+        else:
+            for ckpt in self.ckpt_dir.glob("**/*.ckpt"):
+                ckpts.add_file(str(ckpt))
+
+        # Checkpoints are produced by the run, so log them as an output artifact.
+        experiment.log_artifact(ckpts)
+
+
+@attr.s
+class LogTextPredictions(Callback):
+    """Logs a validation batch with image to text transcription."""
+
+    num_samples: int = attr.ib(default=8)
+    ready: bool = attr.ib(default=True)
+
+    def __attrs_pre_init__(self):
+        """Calls the parent initializer before attrs sets the attributes."""
+        super().__init__()
+
+    def on_sanity_check_start(
+        self, trainer: Trainer, pl_module: LightningModule
+    ) -> None:
+        """Sets ready attribute."""
+        self.ready = False
+
+    def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+        """Start executing this callback only after all validation sanity checks end."""
+        self.ready = True
+
+    def on_validation_epoch_end(
+        self, trainer: Trainer, pl_module: LightningModule
+    ) -> None:
+        """Logs predictions on validation epoch end."""
+        if not self.ready:
+            return None
+
+        logger = get_wandb_logger(trainer)
+        experiment = logger.experiment
+
+        # Get a validation batch from the validation dataloader.
+        samples = next(iter(trainer.datamodule.val_dataloader()))
+        imgs, labels = samples
+
+        imgs = imgs.to(device=pl_module.device)
+        logits = pl_module(imgs)
+
+        # NOTE(review): raw model outputs are passed to mapping.get_text below;
+        # confirm it accepts logits rather than decoded token indices.
+        mapping = pl_module.mapping
+        experiment.log(
+            {
+                f"Images/{experiment.name}": [
+                    wandb.Image(
+                        img,
+                        caption=f"Pred: {mapping.get_text(pred)}, Label: {mapping.get_text(label)}",
+                    )
+                    for img, pred, label in zip(
+                        imgs[: self.num_samples],
+                        logits[: self.num_samples],
+                        labels[: self.num_samples],
+                    )
+                ]
+            }
+        )
diff --git a/training/conf/callbacks/checkpoint.yaml b/training/conf/callbacks/checkpoint.yaml
index f3beb1b..9216715 100644
--- a/training/conf/callbacks/checkpoint.yaml
+++ b/training/conf/callbacks/checkpoint.yaml
@@ -1,6 +1,9 @@
-checkpoint:
-  type: ModelCheckpoint
-  args:
-    monitor: val_loss
-    mode: min
-    save_last: true
+model_checkpoint:
+  _target_: pytorch_lightning.callbacks.ModelCheckpoint
+  monitor: "val/loss" # name of the logged metric which determines when model is improving
+  save_top_k: 1 # save k best models (determined by above metric)
+  save_last: True # additionally always save model from last epoch
+  mode: "min" # can be "max" or "min"
+  verbose: False
+  dirpath: "checkpoints/"
+  filename: "{epoch:02d}"
diff --git a/training/conf/callbacks/default.yaml b/training/conf/callbacks/default.yaml
new file mode 100644
index 0000000..658fc03
--- /dev/null
+++ b/training/conf/callbacks/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+  - checkpoint
+  - learning_rate_monitor
diff --git a/training/conf/callbacks/early_stopping.yaml b/training/conf/callbacks/early_stopping.yaml
index ec671fd..4cd5aa1 100644
--- a/training/conf/callbacks/early_stopping.yaml
+++ b/training/conf/callbacks/early_stopping.yaml
@@ -1,6 +1,6 @@
 early_stopping:
-  type: EarlyStopping
-  args:
-    monitor: val_loss
-    mode: min
-    patience: 10
+  _target_: pytorch_lightning.callbacks.EarlyStopping
+  monitor: "val/loss" # name of the logged metric which determines when model is improving
+  patience: 16 # how many epochs of not improving until training stops
+  mode: "min" # can be "max" or "min"
+  min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement
diff --git a/training/conf/callbacks/learning_rate_monitor.yaml b/training/conf/callbacks/learning_rate_monitor.yaml
index 11a5ecf..4a14e1f 100644
--- a/training/conf/callbacks/learning_rate_monitor.yaml
+++ b/training/conf/callbacks/learning_rate_monitor.yaml
@@ -1,4 +1,4 @@
 learning_rate_monitor:
-  type: LearningRateMonitor
-  args:
-    logging_interval: step
+  _target_: pytorch_lightning.callbacks.LearningRateMonitor
+  logging_interval: step
+  log_momentum: false
diff --git a/training/conf/callbacks/swa.yaml b/training/conf/callbacks/swa.yaml
index 92d9e6b..73f8c66 100644
--- a/training/conf/callbacks/swa.yaml
+++ b/training/conf/callbacks/swa.yaml
@@ -1,8 +1,7 @@
 stochastic_weight_averaging:
-  type: StochasticWeightAveraging
-  args:
-    swa_epoch_start: 0.8
-    swa_lrs: 0.05
-    annealing_epochs: 10
-    annealing_strategy: cos
-    device: null
+  _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+  swa_epoch_start: 0.8
+  swa_lrs: 0.05
+  annealing_epochs: 10
+  annealing_strategy: cos
+  device: null
diff --git a/training/conf/callbacks/wandb.yaml b/training/conf/callbacks/wandb.yaml
new file mode 100644
index 0000000..2d56bfa
--- /dev/null
+++ b/training/conf/callbacks/wandb.yaml
@@ -0,0 +1,20 @@
+defaults:
+  - default.yaml
+
+watch_model:
+  _target_: text_recognizer.callbacks.wandb_callbacks.WatchModel
+  log: "all"
+  log_freq: 100
+
+upload_code_as_artifact:
+  _target_: text_recognizer.callbacks.wandb_callbacks.UploadCodeAsArtifact
+  project_dir: ${work_dir}/text_recognizer
+
+upload_ckpts_as_artifact:
+  _target_: text_recognizer.callbacks.wandb_callbacks.UploadCheckpointAsArtifact
+  ckpt_dir: "checkpoints/"
+  upload_best_only: True
+
+log_text_predictions:
+  _target_: text_recognizer.callbacks.wandb_callbacks.LogTextPredictions
+  num_samples: 8
-- 
cgit v1.2.3-70-g09d2