From 4da7a2c812221d56a430b35139ac40b23fa76f77 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 29 Jun 2021 22:54:52 +0200 Subject: Refactor of config, more granular --- training/conf/callbacks/checkpoint.yaml | 5 +++-- training/conf/callbacks/early_stopping.yaml | 5 +++-- training/conf/callbacks/learning_rate_monitor.yaml | 5 +++-- training/conf/callbacks/swa.yaml | 5 +++-- training/conf/config.yaml | 6 ++++++ training/conf/criterion/mse.yaml | 3 +++ training/conf/dataset/iam_extended_paragraphs.yaml | 9 ++++----- training/conf/lr_scheduler/one_cycle.yaml | 8 ++++++++ training/conf/model/lit_vqvae.yaml | 23 +--------------------- training/conf/network/vqvae.yaml | 23 +++++++++++----------- training/conf/optimizer/madgrad.yaml | 6 ++++++ training/conf/trainer/default.yaml | 23 ++++++++++------------ 12 files changed, 61 insertions(+), 60 deletions(-) create mode 100644 training/conf/criterion/mse.yaml create mode 100644 training/conf/lr_scheduler/one_cycle.yaml create mode 100644 training/conf/optimizer/madgrad.yaml (limited to 'training/conf') diff --git a/training/conf/callbacks/checkpoint.yaml b/training/conf/callbacks/checkpoint.yaml index afc536f..f3beb1b 100644 --- a/training/conf/callbacks/checkpoint.yaml +++ b/training/conf/callbacks/checkpoint.yaml @@ -1,5 +1,6 @@ -type: ModelCheckpoint -args: +checkpoint: + type: ModelCheckpoint + args: monitor: val_loss mode: min save_last: true diff --git a/training/conf/callbacks/early_stopping.yaml b/training/conf/callbacks/early_stopping.yaml index caab824..ec671fd 100644 --- a/training/conf/callbacks/early_stopping.yaml +++ b/training/conf/callbacks/early_stopping.yaml @@ -1,5 +1,6 @@ -type: EarlyStopping -args: +early_stopping: + type: EarlyStopping + args: monitor: val_loss mode: min patience: 10 diff --git a/training/conf/callbacks/learning_rate_monitor.yaml b/training/conf/callbacks/learning_rate_monitor.yaml index 003ab7a..11a5ecf 100644 --- a/training/conf/callbacks/learning_rate_monitor.yaml +++ b/training/conf/callbacks/learning_rate_monitor.yaml @@ -1,3 +1,4 @@ -type: LearningRateMonitor -args: +learning_rate_monitor: + type: LearningRateMonitor + args: logging_interval: step diff --git a/training/conf/callbacks/swa.yaml b/training/conf/callbacks/swa.yaml index 279ca69..92d9e6b 100644 --- a/training/conf/callbacks/swa.yaml +++ b/training/conf/callbacks/swa.yaml @@ -1,5 +1,6 @@ -type: StochasticWeightAveraging -args: +stochastic_weight_averaging: + type: StochasticWeightAveraging + args: swa_epoch_start: 0.8 swa_lrs: 0.05 annealing_epochs: 10 diff --git a/training/conf/config.yaml b/training/conf/config.yaml index c413a1a..b43e375 100644 --- a/training/conf/config.yaml +++ b/training/conf/config.yaml @@ -1,8 +1,14 @@ defaults: - network: vqvae + - criterion: mse + - optimizer: madgrad + - lr_scheduler: one_cycle - model: lit_vqvae - dataset: iam_extended_paragraphs - trainer: default - callbacks: - checkpoint - learning_rate_monitor + +load_checkpoint: null +logging: INFO diff --git a/training/conf/criterion/mse.yaml b/training/conf/criterion/mse.yaml new file mode 100644 index 0000000..4d89cbc --- /dev/null +++ b/training/conf/criterion/mse.yaml @@ -0,0 +1,3 @@ +type: MSELoss +args: + reduction: mean diff --git a/training/conf/dataset/iam_extended_paragraphs.yaml b/training/conf/dataset/iam_extended_paragraphs.yaml index 6bd7fc9..6439a15 100644 --- a/training/conf/dataset/iam_extended_paragraphs.yaml +++ b/training/conf/dataset/iam_extended_paragraphs.yaml @@ -1,7 +1,6 @@ -# @package _group_ type: IAMExtendedParagraphs args: - batch_size: 32 - num_workers: 12 - train_fraction: 0.8 - augment: true + batch_size: 32 + num_workers: 12 + train_fraction: 0.8 + augment: true diff --git a/training/conf/lr_scheduler/one_cycle.yaml b/training/conf/lr_scheduler/one_cycle.yaml new file mode 100644 index 0000000..60a6f27 --- /dev/null +++ b/training/conf/lr_scheduler/one_cycle.yaml @@ -0,0 +1,8 @@ +type: OneCycleLR +args: + interval: step + max_lr: 1.0e-3 + three_phase: true + epochs: 64 + steps_per_epoch: 633 # num_samples / batch_size +monitor: val_loss diff --git a/training/conf/model/lit_vqvae.yaml b/training/conf/model/lit_vqvae.yaml index 90780b7..7136dbd 100644 --- a/training/conf/model/lit_vqvae.yaml +++ b/training/conf/model/lit_vqvae.yaml @@ -1,24 +1,3 @@ -# @package _group_ type: LitVQVAEModel args: - optimizer: - type: MADGRAD - args: - lr: 1.0e-3 - momentum: 0.9 - weight_decay: 0 - eps: 1.0e-6 - lr_scheduler: - type: OneCycleLR - args: - interval: step - max_lr: 1.0e-3 - three_phase: true - epochs: 64 - steps_per_epoch: 633 # num_samples / batch_size - criterion: - type: MSELoss - args: - reduction: mean - monitor: val_loss - mapping: sentence_piece + mapping: sentence_piece diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml index 288d2aa..22eebf8 100644 --- a/training/conf/network/vqvae.yaml +++ b/training/conf/network/vqvae.yaml @@ -1,14 +1,13 @@ -# @package _group_ type: VQVAE args: - in_channels: 1 - channels: [64, 96] - kernel_sizes: [4, 4] - strides: [2, 2] - num_residual_layers: 2 - embedding_dim: 64 - num_embeddings: 256 - upsampling: null - beta: 0.25 - activation: leaky_relu - dropout_rate: 0.2 + in_channels: 1 + channels: [64, 96] + kernel_sizes: [4, 4] + strides: [2, 2] + num_residual_layers: 2 + embedding_dim: 64 + num_embeddings: 256 + upsampling: null + beta: 0.25 + activation: leaky_relu + dropout_rate: 0.2 diff --git a/training/conf/optimizer/madgrad.yaml b/training/conf/optimizer/madgrad.yaml new file mode 100644 index 0000000..2f2cff9 --- /dev/null +++ b/training/conf/optimizer/madgrad.yaml @@ -0,0 +1,6 @@ +type: MADGRAD +args: + lr: 1.0e-3 + momentum: 0.9 + weight_decay: 0 + eps: 1.0e-6 diff --git a/training/conf/trainer/default.yaml b/training/conf/trainer/default.yaml index 3a88c6a..5797741 100644 --- a/training/conf/trainer/default.yaml +++ b/training/conf/trainer/default.yaml @@ -1,19 +1,16 @@ -# @package _group_ seed: 4711 -load_checkpoint: null wandb: false tune: false train: true test: true -logging: INFO args: - stochastic_weight_avg: false - auto_scale_batch_size: binsearch - auto_lr_find: false - gradient_clip_val: 0 - fast_dev_run: false - gpus: 1 - precision: 16 - max_epochs: 64 - terminate_on_nan: true - weights_summary: top + stochastic_weight_avg: false + auto_scale_batch_size: binsearch + auto_lr_find: false + gradient_clip_val: 0 + fast_dev_run: false + gpus: 1 + precision: 16 + max_epochs: 64 + terminate_on_nan: true + weights_summary: top -- cgit v1.2.3-70-g09d2