summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-04 22:58:07 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-04 22:58:07 +0200
commit969e1d5e179d9c42ffae0c9b12c9bd3be6091360 (patch)
treeab849c38bc9b863afad85fd04d6f618031000e6f /training
parent4da7a2c812221d56a430b35139ac40b23fa76f77 (diff)
Add wandb callbacks
Diffstat (limited to 'training')
-rw-r--r--training/conf/callbacks/checkpoint.yaml15
-rw-r--r--training/conf/callbacks/default.yaml3
-rw-r--r--training/conf/callbacks/early_stopping.yaml10
-rw-r--r--training/conf/callbacks/learning_rate_monitor.yaml6
-rw-r--r--training/conf/callbacks/swa.yaml13
-rw-r--r--training/conf/callbacks/wandb.yaml20
6 files changed, 46 insertions, 21 deletions
diff --git a/training/conf/callbacks/checkpoint.yaml b/training/conf/callbacks/checkpoint.yaml
index f3beb1b..9216715 100644
--- a/training/conf/callbacks/checkpoint.yaml
+++ b/training/conf/callbacks/checkpoint.yaml
@@ -1,6 +1,9 @@
-checkpoint:
- type: ModelCheckpoint
- args:
- monitor: val_loss
- mode: min
- save_last: true
+model_checkpoint:
+ _target_: pytorch_lightning.callbacks.ModelCheckpoint
+ monitor: "val/loss" # name of the logged metric which determines when model is improving
+ save_top_k: 1 # save k best models (determined by above metric)
+ save_last: True # additionaly always save model from last epoch
+ mode: "min" # can be "max" or "min"
+ verbose: False
+ dirpath: "checkpoints/"
+ filename: "{epoch:02d}"
diff --git a/training/conf/callbacks/default.yaml b/training/conf/callbacks/default.yaml
new file mode 100644
index 0000000..658fc03
--- /dev/null
+++ b/training/conf/callbacks/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - checkpoint
+ - learning_rate_monitor
diff --git a/training/conf/callbacks/early_stopping.yaml b/training/conf/callbacks/early_stopping.yaml
index ec671fd..4cd5aa1 100644
--- a/training/conf/callbacks/early_stopping.yaml
+++ b/training/conf/callbacks/early_stopping.yaml
@@ -1,6 +1,6 @@
early_stopping:
- type: EarlyStopping
- args:
- monitor: val_loss
- mode: min
- patience: 10
+ _target_: pytorch_lightning.callbacks.EarlyStopping
+ monitor: "val/loss" # name of the logged metric which determines when model is improving
+ patience: 16 # how many epochs of not improving until training stops
+ mode: "min" # can be "max" or "min"
+ min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement
diff --git a/training/conf/callbacks/learning_rate_monitor.yaml b/training/conf/callbacks/learning_rate_monitor.yaml
index 11a5ecf..4a14e1f 100644
--- a/training/conf/callbacks/learning_rate_monitor.yaml
+++ b/training/conf/callbacks/learning_rate_monitor.yaml
@@ -1,4 +1,4 @@
learning_rate_monitor:
- type: LearningRateMonitor
- args:
- logging_interval: step
+ _target_: pytorch_lightning.callbacks.LearningRateMonitor
+ logging_interval: step
+ log_momentum: false
diff --git a/training/conf/callbacks/swa.yaml b/training/conf/callbacks/swa.yaml
index 92d9e6b..73f8c66 100644
--- a/training/conf/callbacks/swa.yaml
+++ b/training/conf/callbacks/swa.yaml
@@ -1,8 +1,7 @@
stochastic_weight_averaging:
- type: StochasticWeightAveraging
- args:
- swa_epoch_start: 0.8
- swa_lrs: 0.05
- annealing_epochs: 10
- annealing_strategy: cos
- device: null
+ _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+ swa_epoch_start: 0.8
+ swa_lrs: 0.05
+ annealing_epochs: 10
+ annealing_strategy: cos
+ device: null
diff --git a/training/conf/callbacks/wandb.yaml b/training/conf/callbacks/wandb.yaml
new file mode 100644
index 0000000..2d56bfa
--- /dev/null
+++ b/training/conf/callbacks/wandb.yaml
@@ -0,0 +1,20 @@
+defaults:
+ - default.yaml
+
+watch_model:
+ _target_: text_recognizer.callbacks.wandb_callbacks.WatchModel
+ log: "all"
+ log_freq: 100
+
+upload_code_as_artifact:
+ _target_: text_recognizer.callbacks.wandb_callbacks.UploadCodeAsArtifact
+ project_dir: ${work_dir}/text_recognizer
+
+upload_ckpts_as_artifact:
+ _target_: text_recognizer.callbacks.wandb_callbacks.UploadCheckpointsAsArtifact
+ ckpt_dir: "checkpoints/"
+ upload_best_only: True
+
+log_text_predictions:
+ _target_: text_recognizer.callbacks.wandb_callbacks.LogTextPredictions
+ num_samples: 8