From 1d0977585f01c42e9f6280559a1a98037907a62e Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 2 May 2021 13:51:15 +0200 Subject: Implemented training script with hydra --- training/conf/callbacks/default.yaml | 14 ++++ training/conf/callbacks/swa.yaml | 16 ++++ training/conf/cnn_transformer.yaml | 90 ++++++++++++++++++++++ training/conf/config.yaml | 6 ++ training/conf/dataset/iam_extended_paragraphs.yaml | 7 ++ training/conf/model/lit_vqvae.yaml | 24 ++++++ training/conf/network/vqvae.yaml | 14 ++++ training/conf/trainer/default.yaml | 18 +++++ 8 files changed, 189 insertions(+) create mode 100644 training/conf/callbacks/default.yaml create mode 100644 training/conf/callbacks/swa.yaml create mode 100644 training/conf/cnn_transformer.yaml create mode 100644 training/conf/config.yaml create mode 100644 training/conf/dataset/iam_extended_paragraphs.yaml create mode 100644 training/conf/model/lit_vqvae.yaml create mode 100644 training/conf/network/vqvae.yaml create mode 100644 training/conf/trainer/default.yaml (limited to 'training/conf') diff --git a/training/conf/callbacks/default.yaml b/training/conf/callbacks/default.yaml new file mode 100644 index 0000000..74dc30c --- /dev/null +++ b/training/conf/callbacks/default.yaml @@ -0,0 +1,14 @@ +# @package _group_ +- type: ModelCheckpoint + args: + monitor: val_loss + mode: min + save_last: true +- type: LearningRateMonitor + args: + logging_interval: step +# - type: EarlyStopping +# args: +# monitor: val_loss +# mode: min +# patience: 10 diff --git a/training/conf/callbacks/swa.yaml b/training/conf/callbacks/swa.yaml new file mode 100644 index 0000000..144ad6e --- /dev/null +++ b/training/conf/callbacks/swa.yaml @@ -0,0 +1,16 @@ +# @package _group_ +- type: ModelCheckpoint + args: + monitor: val_loss + mode: min + save_last: true +- type: StochasticWeightAveraging + args: + swa_epoch_start: 0.8 + swa_lrs: 0.05 + annealing_epochs: 10 + annealing_strategy: cos + device: null +- type: LearningRateMonitor + args: + logging_interval: step diff --git a/training/conf/cnn_transformer.yaml b/training/conf/cnn_transformer.yaml new file mode 100644 index 0000000..a4f16df --- /dev/null +++ b/training/conf/cnn_transformer.yaml @@ -0,0 +1,90 @@ +seed: 4711 + +network: + desc: Configuration of the PyTorch neural network. + type: CNNTransformer + args: + encoder: + type: EfficientNet + args: null + num_decoder_layers: 4 + vocab_size: 84 + hidden_dim: 256 + num_heads: 4 + expansion_dim: 1024 + dropout_rate: 0.1 + transformer_activation: glu + +model: + desc: Configuration of the PyTorch Lightning model. + type: LitTransformerModel + args: + optimizer: + type: MADGRAD + args: + lr: 1.0e-3 + momentum: 0.9 + weight_decay: 0 + eps: 1.0e-6 + lr_scheduler: + type: OneCycleLR + args: + interval: &interval step + max_lr: 1.0e-3 + three_phase: true + epochs: 512 + steps_per_epoch: 1246 # num_samples / batch_size + criterion: + type: CrossEntropyLoss + args: + weight: null + ignore_index: -100 + reduction: mean + monitor: val_loss + mapping: sentence_piece + +data: + desc: Configuration of the training/test data. + type: IAMExtendedParagraphs + args: + batch_size: 8 + num_workers: 12 + train_fraction: 0.8 + augment: true + +callbacks: + - type: ModelCheckpoint + args: + monitor: val_loss + mode: min + save_last: true + # - type: StochasticWeightAveraging + # args: + # swa_epoch_start: 0.8 + # swa_lrs: 0.05 + # annealing_epochs: 10 + # annealing_strategy: cos + # device: null + - type: LearningRateMonitor + args: + logging_interval: *interval + # - type: EarlyStopping + # args: + # monitor: val_loss + # mode: min + # patience: 10 + +trainer: + desc: Configuration of the PyTorch Lightning Trainer. + args: + stochastic_weight_avg: false + auto_scale_batch_size: binsearch + gradient_clip_val: 0 + fast_dev_run: true + gpus: 1 + precision: 16 + max_epochs: 512 + terminate_on_nan: true + weights_summary: top + +load_checkpoint: null diff --git a/training/conf/config.yaml b/training/conf/config.yaml new file mode 100644 index 0000000..11adeb7 --- /dev/null +++ b/training/conf/config.yaml @@ -0,0 +1,6 @@ +defaults: + - network: vqvae + - model: lit_vqvae + - dataset: iam_extended_paragraphs + - trainer: default + - callbacks: default diff --git a/training/conf/dataset/iam_extended_paragraphs.yaml b/training/conf/dataset/iam_extended_paragraphs.yaml new file mode 100644 index 0000000..6bd7fc9 --- /dev/null +++ b/training/conf/dataset/iam_extended_paragraphs.yaml @@ -0,0 +1,7 @@ +# @package _group_ +type: IAMExtendedParagraphs +args: + batch_size: 32 + num_workers: 12 + train_fraction: 0.8 + augment: true diff --git a/training/conf/model/lit_vqvae.yaml b/training/conf/model/lit_vqvae.yaml new file mode 100644 index 0000000..90780b7 --- /dev/null +++ b/training/conf/model/lit_vqvae.yaml @@ -0,0 +1,24 @@ +# @package _group_ +type: LitVQVAEModel +args: + optimizer: + type: MADGRAD + args: + lr: 1.0e-3 + momentum: 0.9 + weight_decay: 0 + eps: 1.0e-6 + lr_scheduler: + type: OneCycleLR + args: + interval: step + max_lr: 1.0e-3 + three_phase: true + epochs: 64 + steps_per_epoch: 633 # num_samples / batch_size + criterion: + type: MSELoss + args: + reduction: mean + monitor: val_loss + mapping: sentence_piece diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml new file mode 100644 index 0000000..8c30bbd --- /dev/null +++ b/training/conf/network/vqvae.yaml @@ -0,0 +1,14 @@ +# @package _group_ +type: VQVAE +args: + in_channels: 1 + channels: [32, 64, 64] + kernel_sizes: [4, 4, 4] + strides: [2, 2, 2] + num_residual_layers: 2 + embedding_dim: 64 + num_embeddings: 256 + upsampling: null + beta: 0.25 + activation: leaky_relu + dropout_rate: 0.2 diff --git a/training/conf/trainer/default.yaml b/training/conf/trainer/default.yaml new file mode 100644 index 0000000..82afd93 --- /dev/null +++ b/training/conf/trainer/default.yaml @@ -0,0 +1,18 @@ +# @package _group_ +seed: 4711 +load_checkpoint: null +wandb: false +tune: false +train: true +test: true +logging: INFO +args: + stochastic_weight_avg: false + auto_scale_batch_size: binsearch + gradient_clip_val: 0 + fast_dev_run: false + gpus: 1 + precision: 16 + max_epochs: 64 + terminate_on_nan: true + weights_summary: top -- cgit v1.2.3-70-g09d2