summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-06-27 16:52:55 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-06-27 16:52:55 +0200
commit0ba945e84d11a07ac95fdf8495f2ff278215adb9 (patch)
treeb22bcecf37416fbdfbe6a5ef178471e1a5751c3e
parent22e36513dd43d2e2ca82ca28a1ea757c5663676a (diff)
Changed structure of callbacks
-rw-r--r--training/conf/callbacks/checkpoint.yaml5
-rw-r--r--training/conf/callbacks/default.yaml14
-rw-r--r--training/conf/callbacks/early_stopping.yaml5
-rw-r--r--training/conf/callbacks/learning_rate_monitor.yaml3
-rw-r--r--training/conf/callbacks/swa.yaml23
-rw-r--r--training/conf/cnn_transformer.yaml90
-rw-r--r--training/conf/config.yaml12
7 files changed, 27 insertions, 125 deletions
diff --git a/training/conf/callbacks/checkpoint.yaml b/training/conf/callbacks/checkpoint.yaml
new file mode 100644
index 0000000..afc536f
--- /dev/null
+++ b/training/conf/callbacks/checkpoint.yaml
@@ -0,0 +1,5 @@
+type: ModelCheckpoint
+args:
+ monitor: val_loss
+ mode: min
+ save_last: true
diff --git a/training/conf/callbacks/default.yaml b/training/conf/callbacks/default.yaml
deleted file mode 100644
index 74dc30c..0000000
--- a/training/conf/callbacks/default.yaml
+++ /dev/null
@@ -1,14 +0,0 @@
-# @package _group_
-- type: ModelCheckpoint
- args:
- monitor: val_loss
- mode: min
- save_last: true
-- type: LearningRateMonitor
- args:
- logging_interval: step
-# - type: EarlyStopping
-# args:
-# monitor: val_loss
-# mode: min
-# patience: 10
diff --git a/training/conf/callbacks/early_stopping.yaml b/training/conf/callbacks/early_stopping.yaml
new file mode 100644
index 0000000..caab824
--- /dev/null
+++ b/training/conf/callbacks/early_stopping.yaml
@@ -0,0 +1,5 @@
+type: EarlyStopping
+args:
+ monitor: val_loss
+ mode: min
+ patience: 10
diff --git a/training/conf/callbacks/learning_rate_monitor.yaml b/training/conf/callbacks/learning_rate_monitor.yaml
new file mode 100644
index 0000000..003ab7a
--- /dev/null
+++ b/training/conf/callbacks/learning_rate_monitor.yaml
@@ -0,0 +1,3 @@
+type: LearningRateMonitor
+args:
+ logging_interval: step
diff --git a/training/conf/callbacks/swa.yaml b/training/conf/callbacks/swa.yaml
index 144ad6e..279ca69 100644
--- a/training/conf/callbacks/swa.yaml
+++ b/training/conf/callbacks/swa.yaml
@@ -1,16 +1,7 @@
-# @package _group_
-- type: ModelCheckpoint
- args:
- monitor: val_loss
- mode: min
- save_last: true
-- type: StochasticWeightAveraging
- args:
- swa_epoch_start: 0.8
- swa_lrs: 0.05
- annealing_epochs: 10
- annealing_strategy: cos
- device: null
-- type: LearningRateMonitor
- args:
- logging_interval: step
+type: StochasticWeightAveraging
+args:
+ swa_epoch_start: 0.8
+ swa_lrs: 0.05
+ annealing_epochs: 10
+ annealing_strategy: cos
+ device: null
diff --git a/training/conf/cnn_transformer.yaml b/training/conf/cnn_transformer.yaml
deleted file mode 100644
index a4f16df..0000000
--- a/training/conf/cnn_transformer.yaml
+++ /dev/null
@@ -1,90 +0,0 @@
-seed: 4711
-
-network:
- desc: Configuration of the PyTorch neural network.
- type: CNNTransformer
- args:
- encoder:
- type: EfficientNet
- args: null
- num_decoder_layers: 4
- vocab_size: 84
- hidden_dim: 256
- num_heads: 4
- expansion_dim: 1024
- dropout_rate: 0.1
- transformer_activation: glu
-
-model:
- desc: Configuration of the PyTorch Lightning model.
- type: LitTransformerModel
- args:
- optimizer:
- type: MADGRAD
- args:
- lr: 1.0e-3
- momentum: 0.9
- weight_decay: 0
- eps: 1.0e-6
- lr_scheduler:
- type: OneCycleLR
- args:
- interval: &interval step
- max_lr: 1.0e-3
- three_phase: true
- epochs: 512
- steps_per_epoch: 1246 # num_samples / batch_size
- criterion:
- type: CrossEntropyLoss
- args:
- weight: null
- ignore_index: -100
- reduction: mean
- monitor: val_loss
- mapping: sentence_piece
-
-data:
- desc: Configuration of the training/test data.
- type: IAMExtendedParagraphs
- args:
- batch_size: 8
- num_workers: 12
- train_fraction: 0.8
- augment: true
-
-callbacks:
- - type: ModelCheckpoint
- args:
- monitor: val_loss
- mode: min
- save_last: true
- # - type: StochasticWeightAveraging
- # args:
- # swa_epoch_start: 0.8
- # swa_lrs: 0.05
- # annealing_epochs: 10
- # annealing_strategy: cos
- # device: null
- - type: LearningRateMonitor
- args:
- logging_interval: *interval
- # - type: EarlyStopping
- # args:
- # monitor: val_loss
- # mode: min
- # patience: 10
-
-trainer:
- desc: Configuration of the PyTorch Lightning Trainer.
- args:
- stochastic_weight_avg: false
- auto_scale_batch_size: binsearch
- gradient_clip_val: 0
- fast_dev_run: true
- gpus: 1
- precision: 16
- max_epochs: 512
- terminate_on_nan: true
- weights_summary: top
-
-load_checkpoint: null
diff --git a/training/conf/config.yaml b/training/conf/config.yaml
index 11adeb7..c413a1a 100644
--- a/training/conf/config.yaml
+++ b/training/conf/config.yaml
@@ -1,6 +1,8 @@
defaults:
- - network: vqvae
- - model: lit_vqvae
- - dataset: iam_extended_paragraphs
- - trainer: default
- - callbacks: default
+ - network: vqvae
+ - model: lit_vqvae
+ - dataset: iam_extended_paragraphs
+ - trainer: default
+ - callbacks:
+ - checkpoint
+ - learning_rate_monitor