diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-24 23:09:20 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-24 23:09:20 +0200 |
commit | 4e60c836fb710baceba570c28c06437db3ad5c9b (patch) | |
tree | 21caf6d1792bd83a47fb3d372ee7120211e83f18 /training/configs | |
parent | 1ca8b0b9e0613c1e02f6a5d8b49e20c4d6916412 (diff) |
Implementing CoaT transformer, continue tomorrow...
Diffstat (limited to 'training/configs')
-rw-r--r-- | training/configs/vqvae.yaml | 44 |
1 files changed, 22 insertions, 22 deletions
diff --git a/training/configs/vqvae.yaml b/training/configs/vqvae.yaml index 90082f7..a7acb3a 100644 --- a/training/configs/vqvae.yaml +++ b/training/configs/vqvae.yaml @@ -5,16 +5,16 @@ network: type: VQVAE args: in_channels: 1 - channels: [32, 64, 96] + channels: [32, 64, 64] kernel_sizes: [4, 4, 4] strides: [2, 2, 2] num_residual_layers: 2 - embedding_dim: 64 - num_embeddings: 1024 + embedding_dim: 128 + num_embeddings: 512 upsampling: null beta: 0.25 activation: leaky_relu - dropout_rate: 0.1 + dropout_rate: 0.2 model: desc: Configuration of the PyTorch Lightning model. @@ -33,8 +33,8 @@ model: interval: &interval step max_lr: 1.0e-3 three_phase: true - epochs: 512 - steps_per_epoch: 317 # num_samples / batch_size + epochs: 64 + steps_per_epoch: 633 # num_samples / batch_size criterion: type: MSELoss args: @@ -46,7 +46,7 @@ data: desc: Configuration of the training/test data. type: IAMExtendedParagraphs args: - batch_size: 64 + batch_size: 32 num_workers: 12 train_fraction: 0.8 augment: true @@ -57,33 +57,33 @@ callbacks: monitor: val_loss mode: min save_last: true - # - type: StochasticWeightAveraging - # args: - # swa_epoch_start: 0.8 - # swa_lrs: 0.05 - # annealing_epochs: 10 - # annealing_strategy: cos - # device: null + - type: StochasticWeightAveraging + args: + swa_epoch_start: 0.8 + swa_lrs: 0.05 + annealing_epochs: 10 + annealing_strategy: cos + device: null - type: LearningRateMonitor args: logging_interval: *interval - - type: EarlyStopping - args: - monitor: val_loss - mode: min - patience: 10 + # - type: EarlyStopping + # args: + # monitor: val_loss + # mode: min + # patience: 10 trainer: desc: Configuration of the PyTorch Lightning Trainer. args: - stochastic_weight_avg: false # true + stochastic_weight_avg: true auto_scale_batch_size: binsearch gradient_clip_val: 0 fast_dev_run: false gpus: 1 precision: 16 - max_epochs: 512 + max_epochs: 64 terminate_on_nan: true - weights_summary: full + weights_summary: top load_checkpoint: null |