summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-04 22:58:07 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-04 22:58:07 +0200
commit969e1d5e179d9c42ffae0c9b12c9bd3be6091360 (patch)
treeab849c38bc9b863afad85fd04d6f618031000e6f
parent4da7a2c812221d56a430b35139ac40b23fa76f77 (diff)
Add wandb callbacks
-rw-r--r--text_recognizer/callbacks/__init__.py0
-rw-r--r--text_recognizer/callbacks/wandb_callbacks.py129
-rw-r--r--training/conf/callbacks/checkpoint.yaml15
-rw-r--r--training/conf/callbacks/default.yaml3
-rw-r--r--training/conf/callbacks/early_stopping.yaml10
-rw-r--r--training/conf/callbacks/learning_rate_monitor.yaml6
-rw-r--r--training/conf/callbacks/swa.yaml13
-rw-r--r--training/conf/callbacks/wandb.yaml20
8 files changed, 175 insertions, 21 deletions
diff --git a/text_recognizer/callbacks/__init__.py b/text_recognizer/callbacks/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/text_recognizer/callbacks/__init__.py
diff --git a/text_recognizer/callbacks/wandb_callbacks.py b/text_recognizer/callbacks/wandb_callbacks.py
new file mode 100644
index 0000000..3936aaf
--- /dev/null
+++ b/text_recognizer/callbacks/wandb_callbacks.py
@@ -0,0 +1,129 @@
+"""Weights and Biases callbacks."""
+from pathlib import Path
+from typing import List
+
+import attr
+import wandb
+from pytorch_lightning import Callback, LightningModule, Trainer
+from pytorch_lightning.loggers import LoggerCollection, WandbLogger
+
+
+def get_wandb_logger(trainer: Trainer) -> WandbLogger:
+ """Safely get W&B logger from Trainer."""
+
+ if isinstance(trainer.logger, WandbLogger):
+ return trainer.logger
+
+ if isinstance(trainer.logger, LoggerCollection):
+ for logger in trainer.logger:
+ if isinstance(logger, WandbLogger):
+ return logger
+
+ raise Exception("Weight and Biases logger not found for some reason...")
+
+
+@attr.s
+class WatchModel(Callback):
+ """Make W&B watch the model at the beginning of the run."""
+
+ log: str = attr.ib(default="gradients")
+ log_freq: int = attr.ib(default=100)
+
+ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+ """Watches model weights with wandb."""
+ logger = get_wandb_logger(trainer)
+ logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq)
+
+
+@attr.s
+class UploadCodeAsArtifact(Callback):
+ """Upload all *.py files to W&B as an artifact, at the beginning of the run."""
+
+ project_dir: Path = attr.ib(converter=Path)
+
+ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+ """Uploads project code as an artifact."""
+ logger = get_wandb_logger(trainer)
+ experiment = logger.experiment
+ artifact = wandb.Artifact("project-source", type="code")
+ for filepath in self.project_dir.glob("**/*.py"):
+ artifact.add_file(filepath)
+
+ experiment.use_artifact(artifact)
+
+
+@attr.s
+class UploadCheckpointAsArtifact(Callback):
+ """Upload checkpoint to wandb as an artifact, at the end of a run."""
+
+ ckpt_dir: Path = attr.ib(converter=Path)
+ upload_best_only: bool = attr.ib()
+
+ def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+ """Uploads model checkpoint to W&B."""
+ logger = get_wandb_logger(trainer)
+ experiment = logger.experiment
+ ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints")
+
+ if self.upload_best_only:
+ ckpts.add_file(trainer.checkpoint_callback.best_model_path)
+ else:
+ for ckpt in (self.ckpt_dir).glob("**/*.ckpt"):
+ ckpts.add_file(ckpt)
+
+ experiment.use_artifact(ckpts)
+
+
+@attr.s
+class LogTextPredictions(Callback):
+ """Logs a validation batch with image to text transcription."""
+
+ num_samples: int = attr.ib(default=8)
+ ready: bool = attr.ib(default=True)
+
+ def __attrs_pre_init__(self):
+ super().__init__()
+
+ def on_sanity_check_start(
+ self, trainer: Trainer, pl_module: LightningModule
+ ) -> None:
+ """Sets ready attribute."""
+ self.ready = False
+
+ def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+ """Start executing this callback only after all validation sanity checks end."""
+ self.ready = True
+
+ def on_validation_epoch_end(
+ self, trainer: Trainer, pl_module: LightningModule
+ ) -> None:
+ """Logs predictions on validation epoch end."""
+ if not self.ready:
+ return None
+
+ logger = get_wandb_logger(trainer)
+ experiment = logger.experiment
+
+ # Get a validation batch from the validation dataloader.
+ samples = next(iter(trainer.datamodule.val_dataloader()))
+ imgs, labels = samples
+
+ imgs = imgs.to(device=pl_module.device)
+ logits = pl_module(imgs)
+
+ mapping = pl_module.mapping
+ experiment.log(
+ {
+ f"Images/{experiment.name}": [
+ wandb.Image(
+ img,
+ caption=f"Pred: {mapping.get_text(pred)}, Label: {mapping.get_text(label)}",
+ )
+ for img, pred, label in zip(
+ imgs[: self.num_samples],
+ logits[: self.num_samples],
+ labels[: self.num_samples],
+ )
+ ]
+ }
+ )
diff --git a/training/conf/callbacks/checkpoint.yaml b/training/conf/callbacks/checkpoint.yaml
index f3beb1b..9216715 100644
--- a/training/conf/callbacks/checkpoint.yaml
+++ b/training/conf/callbacks/checkpoint.yaml
@@ -1,6 +1,9 @@
-checkpoint:
- type: ModelCheckpoint
- args:
- monitor: val_loss
- mode: min
- save_last: true
+model_checkpoint:
+ _target_: pytorch_lightning.callbacks.ModelCheckpoint
+ monitor: "val/loss" # name of the logged metric which determines when model is improving
+ save_top_k: 1 # save k best models (determined by above metric)
+ save_last: True # additionaly always save model from last epoch
+ mode: "min" # can be "max" or "min"
+ verbose: False
+ dirpath: "checkpoints/"
+ filename: "{epoch:02d}"
diff --git a/training/conf/callbacks/default.yaml b/training/conf/callbacks/default.yaml
new file mode 100644
index 0000000..658fc03
--- /dev/null
+++ b/training/conf/callbacks/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - checkpoint
+ - learning_rate_monitor
diff --git a/training/conf/callbacks/early_stopping.yaml b/training/conf/callbacks/early_stopping.yaml
index ec671fd..4cd5aa1 100644
--- a/training/conf/callbacks/early_stopping.yaml
+++ b/training/conf/callbacks/early_stopping.yaml
@@ -1,6 +1,6 @@
early_stopping:
- type: EarlyStopping
- args:
- monitor: val_loss
- mode: min
- patience: 10
+ _target_: pytorch_lightning.callbacks.EarlyStopping
+ monitor: "val/loss" # name of the logged metric which determines when model is improving
+ patience: 16 # how many epochs of not improving until training stops
+ mode: "min" # can be "max" or "min"
+ min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement
diff --git a/training/conf/callbacks/learning_rate_monitor.yaml b/training/conf/callbacks/learning_rate_monitor.yaml
index 11a5ecf..4a14e1f 100644
--- a/training/conf/callbacks/learning_rate_monitor.yaml
+++ b/training/conf/callbacks/learning_rate_monitor.yaml
@@ -1,4 +1,4 @@
learning_rate_monitor:
- type: LearningRateMonitor
- args:
- logging_interval: step
+ _target_: pytorch_lightning.callbacks.LearningRateMonitor
+ logging_interval: step
+ log_momentum: false
diff --git a/training/conf/callbacks/swa.yaml b/training/conf/callbacks/swa.yaml
index 92d9e6b..73f8c66 100644
--- a/training/conf/callbacks/swa.yaml
+++ b/training/conf/callbacks/swa.yaml
@@ -1,8 +1,7 @@
stochastic_weight_averaging:
- type: StochasticWeightAveraging
- args:
- swa_epoch_start: 0.8
- swa_lrs: 0.05
- annealing_epochs: 10
- annealing_strategy: cos
- device: null
+ _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+ swa_epoch_start: 0.8
+ swa_lrs: 0.05
+ annealing_epochs: 10
+ annealing_strategy: cos
+ device: null
diff --git a/training/conf/callbacks/wandb.yaml b/training/conf/callbacks/wandb.yaml
new file mode 100644
index 0000000..2d56bfa
--- /dev/null
+++ b/training/conf/callbacks/wandb.yaml
@@ -0,0 +1,20 @@
+defaults:
+ - default.yaml
+
+watch_model:
+ _target_: text_recognizer.callbacks.wandb_callbacks.WatchModel
+ log: "all"
+ log_freq: 100
+
+upload_code_as_artifact:
+ _target_: text_recognizer.callbacks.wandb_callbacks.UploadCodeAsArtifact
+ project_dir: ${work_dir}/text_recognizer
+
+upload_ckpts_as_artifact:
+ _target_: text_recognizer.callbacks.wandb_callbacks.UploadCheckpointsAsArtifact
+ ckpt_dir: "checkpoints/"
+ upload_best_only: True
+
+log_text_predictions:
+ _target_: text_recognizer.callbacks.wandb_callbacks.LogTextPredictions
+ num_samples: 8