summary | refs | log | tree | commit | diff
path: root/training/configs/vqvae.yaml
diff options
context:
space:
mode:
author: Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-05-02 13:51:15 +0200
committer: Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-05-02 13:51:15 +0200
commit: 1d0977585f01c42e9f6280559a1a98037907a62e (patch)
tree: 7e86dd71b163f3138ed2658cb52c44e805f21539 /training/configs/vqvae.yaml
parent: 58ae7154aa945cfe5a46592cc1dfb28f0a4e51b3 (diff)
Implemented training script with hydra
Diffstat (limited to 'training/configs/vqvae.yaml')
-rw-r--r-- training/configs/vqvae.yaml 89
1 files changed, 0 insertions, 89 deletions
diff --git a/training/configs/vqvae.yaml b/training/configs/vqvae.yaml
deleted file mode 100644
index 13d7c97..0000000
--- a/training/configs/vqvae.yaml
+++ /dev/null
@@ -1,89 +0,0 @@
-seed: 4711
-
-network:
- desc: Configuration of the PyTorch neural network.
- type: VQVAE
- args:
- in_channels: 1
- channels: [32, 64, 64, 96, 96]
- kernel_sizes: [4, 4, 4, 4, 4]
- strides: [2, 2, 2, 2, 2]
- num_residual_layers: 2
- embedding_dim: 512
- num_embeddings: 1024
- upsampling: null
- beta: 0.25
- activation: leaky_relu
- dropout_rate: 0.2
-
-model:
- desc: Configuration of the PyTorch Lightning model.
- type: LitVQVAEModel
- args:
- optimizer:
- type: MADGRAD
- args:
- lr: 1.0e-3
- momentum: 0.9
- weight_decay: 0
- eps: 1.0e-6
- lr_scheduler:
- type: OneCycleLR
- args:
- interval: &interval step
- max_lr: 1.0e-3
- three_phase: true
- epochs: 64
- steps_per_epoch: 633 # num_samples / batch_size
- criterion:
- type: MSELoss
- args:
- reduction: mean
- monitor: val_loss
- mapping: sentence_piece
-
-data:
- desc: Configuration of the training/test data.
- type: IAMExtendedParagraphs
- args:
- batch_size: 32
- num_workers: 12
- train_fraction: 0.8
- augment: true
-
-callbacks:
- - type: ModelCheckpoint
- args:
- monitor: val_loss
- mode: min
- save_last: true
- - type: StochasticWeightAveraging
- args:
- swa_epoch_start: 0.8
- swa_lrs: 0.05
- annealing_epochs: 10
- annealing_strategy: cos
- device: null
- - type: LearningRateMonitor
- args:
- logging_interval: *interval
- # - type: EarlyStopping
- # args:
- # monitor: val_loss
- # mode: min
- # patience: 10
-
-trainer:
- desc: Configuration of the PyTorch Lightning Trainer.
- args:
- stochastic_weight_avg: true
- auto_scale_batch_size: binsearch
- gradient_clip_val: 0
- fast_dev_run: false
- gpus: 1
- precision: 16
- max_epochs: 64
- terminate_on_nan: true
- weights_summary: top
-
-load_checkpoint: null