summaryrefslogtreecommitdiff
path: root/training/conf/callbacks
diff options
context:
space:
mode:
Diffstat (limited to 'training/conf/callbacks')
-rw-r--r--training/conf/callbacks/htr.yaml6
-rw-r--r--training/conf/callbacks/lightning/checkpoint.yaml9
-rw-r--r--training/conf/callbacks/lightning/early_stopping.yaml6
-rw-r--r--training/conf/callbacks/lightning/learning_rate_monitor.yaml4
-rw-r--r--training/conf/callbacks/vae.yaml6
-rw-r--r--training/conf/callbacks/wandb/checkpoints.yaml4
-rw-r--r--training/conf/callbacks/wandb/config.yaml2
-rw-r--r--training/conf/callbacks/wandb/predictions.yaml3
-rw-r--r--training/conf/callbacks/wandb/reconstructions.yaml4
-rw-r--r--training/conf/callbacks/wandb/swa.yaml7
-rw-r--r--training/conf/callbacks/wandb/watch.yaml4
11 files changed, 55 insertions, 0 deletions
diff --git a/training/conf/callbacks/htr.yaml b/training/conf/callbacks/htr.yaml
new file mode 100644
index 0000000..51c68c5
--- /dev/null
+++ b/training/conf/callbacks/htr.yaml
@@ -0,0 +1,6 @@
+defaults:  # composition list for the HTR callback set; presumably a Hydra defaults list — confirm
+ - default  # not added by this change; expected to exist already
+ - wandb: watch
+ - wandb: config
+ - wandb: checkpoints
+ - wandb: htr_predictions  # NOTE(review): this change adds wandb/predictions.yaml but no htr_predictions.yaml — confirm the target exists
diff --git a/training/conf/callbacks/lightning/checkpoint.yaml b/training/conf/callbacks/lightning/checkpoint.yaml
new file mode 100644
index 0000000..b4101d8
--- /dev/null
+++ b/training/conf/callbacks/lightning/checkpoint.yaml
@@ -0,0 +1,9 @@
+model_checkpoint:  # checkpoint-saving callback settings
+ _target_: pytorch_lightning.callbacks.ModelCheckpoint
+ monitor: val/loss # name of the logged metric which determines when model is improving
+ save_top_k: 1 # save k best models (determined by above metric)
+ save_last: true # additionally always save the model from the last epoch
+ mode: min # can be "max" or "min"
+ verbose: false
+ dirpath: checkpoints/  # same directory that wandb/checkpoints.yaml passes as ckpt_dir
+ filename: "{epoch:02d}"  # NOTE(review): looks like an epoch-number template, zero-padded to two digits — confirm
diff --git a/training/conf/callbacks/lightning/early_stopping.yaml b/training/conf/callbacks/lightning/early_stopping.yaml
new file mode 100644
index 0000000..a188df3
--- /dev/null
+++ b/training/conf/callbacks/lightning/early_stopping.yaml
@@ -0,0 +1,6 @@
+early_stopping:  # early-stopping callback settings; monitors the same val/loss metric as lightning/checkpoint.yaml
+ _target_: pytorch_lightning.callbacks.EarlyStopping
+ monitor: val/loss # name of the logged metric which determines when model is improving
+ patience: 16 # how many epochs of not improving until training stops
+ mode: min # can be "max" or "min"
+ min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement
diff --git a/training/conf/callbacks/lightning/learning_rate_monitor.yaml b/training/conf/callbacks/lightning/learning_rate_monitor.yaml
new file mode 100644
index 0000000..4a14e1f
--- /dev/null
+++ b/training/conf/callbacks/lightning/learning_rate_monitor.yaml
@@ -0,0 +1,4 @@
+learning_rate_monitor:  # learning-rate monitor callback settings
+ _target_: pytorch_lightning.callbacks.LearningRateMonitor
+ logging_interval: step  # presumably logs once per step rather than per epoch — confirm against the callback's docs
+ log_momentum: false
diff --git a/training/conf/callbacks/vae.yaml b/training/conf/callbacks/vae.yaml
new file mode 100644
index 0000000..eec2c1f
--- /dev/null
+++ b/training/conf/callbacks/vae.yaml
@@ -0,0 +1,6 @@
+defaults:  # composition list for the VAE callback set; presumably a Hydra defaults list — confirm
+ - default  # not added by this change; expected to exist already
+ - wandb: watch
+ - wandb: checkpoints
+ - wandb: reconstructions  # presumably resolves to wandb/reconstructions.yaml from this same change
+ - wandb: config
diff --git a/training/conf/callbacks/wandb/checkpoints.yaml b/training/conf/callbacks/wandb/checkpoints.yaml
new file mode 100644
index 0000000..a4a16ff
--- /dev/null
+++ b/training/conf/callbacks/wandb/checkpoints.yaml
@@ -0,0 +1,4 @@
+upload_ckpts_as_artifact:  # settings for the project's UploadCheckpointsAsArtifact callback
+ _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact
+ ckpt_dir: checkpoints/  # matches dirpath in lightning/checkpoint.yaml; keep the two in sync
+ upload_best_only: true
diff --git a/training/conf/callbacks/wandb/config.yaml b/training/conf/callbacks/wandb/config.yaml
new file mode 100644
index 0000000..747a7c6
--- /dev/null
+++ b/training/conf/callbacks/wandb/config.yaml
@@ -0,0 +1,2 @@
+upload_code_as_artifact:  # NOTE(review): key says "code" but the target is UploadConfigAsArtifact — confirm the intended name
+ _target_: callbacks.wandb_callbacks.UploadConfigAsArtifact
diff --git a/training/conf/callbacks/wandb/predictions.yaml b/training/conf/callbacks/wandb/predictions.yaml
new file mode 100644
index 0000000..573fa96
--- /dev/null
+++ b/training/conf/callbacks/wandb/predictions.yaml
@@ -0,0 +1,3 @@
+log_text_predictions:  # settings for the project's LogTextPredictions callback
+ _target_: callbacks.wandb_callbacks.LogTextPredictions
+ num_samples: 8  # presumably the number of samples logged — confirm against the callback
diff --git a/training/conf/callbacks/wandb/reconstructions.yaml b/training/conf/callbacks/wandb/reconstructions.yaml
new file mode 100644
index 0000000..92f2d12
--- /dev/null
+++ b/training/conf/callbacks/wandb/reconstructions.yaml
@@ -0,0 +1,4 @@
+log_image_reconstruction:  # settings for the image-reconstruction logging callback
+ _target_: callbacks.wandb_callbacks.LogReconstuctedImages  # NOTE(review): "Reconstucted" lacks the second "r" — must match the class name as defined; verify
+ num_samples: 8
+ use_sigmoid: true
diff --git a/training/conf/callbacks/wandb/swa.yaml b/training/conf/callbacks/wandb/swa.yaml
new file mode 100644
index 0000000..73f8c66
--- /dev/null
+++ b/training/conf/callbacks/wandb/swa.yaml
@@ -0,0 +1,7 @@
+stochastic_weight_averaging:  # NOTE(review): file sits under wandb/ but the target is a pytorch_lightning callback — confirm placement
+ _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+ swa_epoch_start: 0.8
+ swa_lrs: 0.05
+ annealing_epochs: 10
+ annealing_strategy: cos
+ device: null
diff --git a/training/conf/callbacks/wandb/watch.yaml b/training/conf/callbacks/wandb/watch.yaml
new file mode 100644
index 0000000..511608c
--- /dev/null
+++ b/training/conf/callbacks/wandb/watch.yaml
@@ -0,0 +1,4 @@
+watch_model:  # settings for the project's WatchModel callback
+ _target_: callbacks.wandb_callbacks.WatchModel
+ log: all
+ log_freq: 100