summary | refs | log | tree | commit | diff
path: root/training/conf
diff options
context:
space:
mode:
author: Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-06-29 22:54:52 +0200
committer: Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-06-29 22:54:52 +0200
commit: 4da7a2c812221d56a430b35139ac40b23fa76f77 (patch)
tree: 69528c08aa97f57918bc23cd0cc2ab6388ee0470 /training/conf
parent: e22453c6e6ff10a610348778f8107799c1125d3b (diff)
Refactor of config, more granular
Diffstat (limited to 'training/conf')
-rw-r--r-- training/conf/callbacks/checkpoint.yaml 5
-rw-r--r-- training/conf/callbacks/early_stopping.yaml 5
-rw-r--r-- training/conf/callbacks/learning_rate_monitor.yaml 5
-rw-r--r-- training/conf/callbacks/swa.yaml 5
-rw-r--r-- training/conf/config.yaml 6
-rw-r--r-- training/conf/criterion/mse.yaml 3
-rw-r--r-- training/conf/dataset/iam_extended_paragraphs.yaml 9
-rw-r--r-- training/conf/lr_scheduler/one_cycle.yaml 8
-rw-r--r-- training/conf/model/lit_vqvae.yaml 23
-rw-r--r-- training/conf/network/vqvae.yaml 23
-rw-r--r-- training/conf/optimizer/madgrad.yaml 6
-rw-r--r-- training/conf/trainer/default.yaml 23
12 files changed, 61 insertions, 60 deletions
diff --git a/training/conf/callbacks/checkpoint.yaml b/training/conf/callbacks/checkpoint.yaml
index afc536f..f3beb1b 100644
--- a/training/conf/callbacks/checkpoint.yaml
+++ b/training/conf/callbacks/checkpoint.yaml
@@ -1,5 +1,6 @@
-type: ModelCheckpoint
-args:
+checkpoint:
+ type: ModelCheckpoint
+ args:
monitor: val_loss
mode: min
save_last: true
diff --git a/training/conf/callbacks/early_stopping.yaml b/training/conf/callbacks/early_stopping.yaml
index caab824..ec671fd 100644
--- a/training/conf/callbacks/early_stopping.yaml
+++ b/training/conf/callbacks/early_stopping.yaml
@@ -1,5 +1,6 @@
-type: EarlyStopping
-args:
+early_stopping:
+ type: EarlyStopping
+ args:
monitor: val_loss
mode: min
patience: 10
diff --git a/training/conf/callbacks/learning_rate_monitor.yaml b/training/conf/callbacks/learning_rate_monitor.yaml
index 003ab7a..11a5ecf 100644
--- a/training/conf/callbacks/learning_rate_monitor.yaml
+++ b/training/conf/callbacks/learning_rate_monitor.yaml
@@ -1,3 +1,4 @@
-type: LearningRateMonitor
-args:
+learning_rate_monitor:
+ type: LearningRateMonitor
+ args:
logging_interval: step
diff --git a/training/conf/callbacks/swa.yaml b/training/conf/callbacks/swa.yaml
index 279ca69..92d9e6b 100644
--- a/training/conf/callbacks/swa.yaml
+++ b/training/conf/callbacks/swa.yaml
@@ -1,5 +1,6 @@
-type: StochasticWeightAveraging
-args:
+stochastic_weight_averaging:
+ type: StochasticWeightAveraging
+ args:
swa_epoch_start: 0.8
swa_lrs: 0.05
annealing_epochs: 10
diff --git a/training/conf/config.yaml b/training/conf/config.yaml
index c413a1a..b43e375 100644
--- a/training/conf/config.yaml
+++ b/training/conf/config.yaml
@@ -1,8 +1,14 @@
defaults:
- network: vqvae
+ - criterion: mse
+ - optimizer: madgrad
+ - lr_scheduler: one_cycle
- model: lit_vqvae
- dataset: iam_extended_paragraphs
- trainer: default
- callbacks:
- checkpoint
- learning_rate_monitor
+
+load_checkpoint: null
+logging: INFO
diff --git a/training/conf/criterion/mse.yaml b/training/conf/criterion/mse.yaml
new file mode 100644
index 0000000..4d89cbc
--- /dev/null
+++ b/training/conf/criterion/mse.yaml
@@ -0,0 +1,3 @@
+type: MSELoss
+args:
+ reduction: mean
diff --git a/training/conf/dataset/iam_extended_paragraphs.yaml b/training/conf/dataset/iam_extended_paragraphs.yaml
index 6bd7fc9..6439a15 100644
--- a/training/conf/dataset/iam_extended_paragraphs.yaml
+++ b/training/conf/dataset/iam_extended_paragraphs.yaml
@@ -1,7 +1,6 @@
-# @package _group_
type: IAMExtendedParagraphs
args:
- batch_size: 32
- num_workers: 12
- train_fraction: 0.8
- augment: true
+ batch_size: 32
+ num_workers: 12
+ train_fraction: 0.8
+ augment: true
diff --git a/training/conf/lr_scheduler/one_cycle.yaml b/training/conf/lr_scheduler/one_cycle.yaml
new file mode 100644
index 0000000..60a6f27
--- /dev/null
+++ b/training/conf/lr_scheduler/one_cycle.yaml
@@ -0,0 +1,8 @@
+type: OneCycleLR
+args:
+ interval: step
+ max_lr: 1.0e-3
+ three_phase: true
+ epochs: 64
+ steps_per_epoch: 633 # num_samples / batch_size
+monitor: val_loss
diff --git a/training/conf/model/lit_vqvae.yaml b/training/conf/model/lit_vqvae.yaml
index 90780b7..7136dbd 100644
--- a/training/conf/model/lit_vqvae.yaml
+++ b/training/conf/model/lit_vqvae.yaml
@@ -1,24 +1,3 @@
-# @package _group_
type: LitVQVAEModel
args:
- optimizer:
- type: MADGRAD
- args:
- lr: 1.0e-3
- momentum: 0.9
- weight_decay: 0
- eps: 1.0e-6
- lr_scheduler:
- type: OneCycleLR
- args:
- interval: step
- max_lr: 1.0e-3
- three_phase: true
- epochs: 64
- steps_per_epoch: 633 # num_samples / batch_size
- criterion:
- type: MSELoss
- args:
- reduction: mean
- monitor: val_loss
- mapping: sentence_piece
+ mapping: sentence_piece
diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml
index 288d2aa..22eebf8 100644
--- a/training/conf/network/vqvae.yaml
+++ b/training/conf/network/vqvae.yaml
@@ -1,14 +1,13 @@
-# @package _group_
type: VQVAE
args:
- in_channels: 1
- channels: [64, 96]
- kernel_sizes: [4, 4]
- strides: [2, 2]
- num_residual_layers: 2
- embedding_dim: 64
- num_embeddings: 256
- upsampling: null
- beta: 0.25
- activation: leaky_relu
- dropout_rate: 0.2
+ in_channels: 1
+ channels: [64, 96]
+ kernel_sizes: [4, 4]
+ strides: [2, 2]
+ num_residual_layers: 2
+ embedding_dim: 64
+ num_embeddings: 256
+ upsampling: null
+ beta: 0.25
+ activation: leaky_relu
+ dropout_rate: 0.2
diff --git a/training/conf/optimizer/madgrad.yaml b/training/conf/optimizer/madgrad.yaml
new file mode 100644
index 0000000..2f2cff9
--- /dev/null
+++ b/training/conf/optimizer/madgrad.yaml
@@ -0,0 +1,6 @@
+type: MADGRAD
+args:
+ lr: 1.0e-3
+ momentum: 0.9
+ weight_decay: 0
+ eps: 1.0e-6
diff --git a/training/conf/trainer/default.yaml b/training/conf/trainer/default.yaml
index 3a88c6a..5797741 100644
--- a/training/conf/trainer/default.yaml
+++ b/training/conf/trainer/default.yaml
@@ -1,19 +1,16 @@
-# @package _group_
seed: 4711
-load_checkpoint: null
wandb: false
tune: false
train: true
test: true
-logging: INFO
args:
- stochastic_weight_avg: false
- auto_scale_batch_size: binsearch
- auto_lr_find: false
- gradient_clip_val: 0
- fast_dev_run: false
- gpus: 1
- precision: 16
- max_epochs: 64
- terminate_on_nan: true
- weights_summary: top
+ stochastic_weight_avg: false
+ auto_scale_batch_size: binsearch
+ auto_lr_find: false
+ gradient_clip_val: 0
+ fast_dev_run: false
+ gpus: 1
+ precision: 16
+ max_epochs: 64
+ terminate_on_nan: true
+ weights_summary: top