From 1ca8b0b9e0613c1e02f6a5d8b49e20c4d6916412 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Thu, 22 Apr 2021 08:15:58 +0200 Subject: Fixed training script, able to train vqvae --- training/configs/image_transformer.yaml | 2 + training/configs/vqvae.yaml | 89 +++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+) create mode 100644 training/configs/vqvae.yaml (limited to 'training/configs') diff --git a/training/configs/image_transformer.yaml b/training/configs/image_transformer.yaml index 228e53f..e6637f2 100644 --- a/training/configs/image_transformer.yaml +++ b/training/configs/image_transformer.yaml @@ -85,3 +85,5 @@ trainer: max_epochs: 512 terminate_on_nan: true weights_summary: true + +load_checkpoint: null diff --git a/training/configs/vqvae.yaml b/training/configs/vqvae.yaml new file mode 100644 index 0000000..90082f7 --- /dev/null +++ b/training/configs/vqvae.yaml @@ -0,0 +1,89 @@ +seed: 4711 + +network: + desc: Configuration of the PyTorch neural network. + type: VQVAE + args: + in_channels: 1 + channels: [32, 64, 96] + kernel_sizes: [4, 4, 4] + strides: [2, 2, 2] + num_residual_layers: 2 + embedding_dim: 64 + num_embeddings: 1024 + upsampling: null + beta: 0.25 + activation: leaky_relu + dropout_rate: 0.1 + +model: + desc: Configuration of the PyTorch Lightning model. + type: LitVQVAEModel + args: + optimizer: + type: MADGRAD + args: + lr: 1.0e-3 + momentum: 0.9 + weight_decay: 0 + eps: 1.0e-6 + lr_scheduler: + type: OneCycleLR + args: + interval: &interval step + max_lr: 1.0e-3 + three_phase: true + epochs: 512 + steps_per_epoch: 317 # num_samples / batch_size + criterion: + type: MSELoss + args: + reduction: mean + monitor: val_loss + mapping: sentence_piece + +data: + desc: Configuration of the training/test data.
+ type: IAMExtendedParagraphs + args: + batch_size: 64 + num_workers: 12 + train_fraction: 0.8 + augment: true + +callbacks: + - type: ModelCheckpoint + args: + monitor: val_loss + mode: min + save_last: true + # - type: StochasticWeightAveraging + # args: + # swa_epoch_start: 0.8 + # swa_lrs: 0.05 + # annealing_epochs: 10 + # annealing_strategy: cos + # device: null + - type: LearningRateMonitor + args: + logging_interval: *interval + - type: EarlyStopping + args: + monitor: val_loss + mode: min + patience: 10 + +trainer: + desc: Configuration of the PyTorch Lightning Trainer. + args: + stochastic_weight_avg: false # true + auto_scale_batch_size: binsearch + gradient_clip_val: 0 + fast_dev_run: false + gpus: 1 + precision: 16 + max_epochs: 512 + terminate_on_nan: true + weights_summary: full + +load_checkpoint: null -- cgit v1.2.3-70-g09d2