diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-25 23:32:50 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-25 23:32:50 +0200 |
commit | 9426cc794d8c28a65bbbf5ae5466a0a343078558 (patch) | |
tree | 44e31b0a7c58597d603ac29a693462aae4b6e9b0 /training/configs/cnn_transformer.yaml | |
parent | 4e60c836fb710baceba570c28c06437db3ad5c9b (diff) |
Efficient net and non working transformer model.
Diffstat (limited to 'training/configs/cnn_transformer.yaml')
-rw-r--r-- | training/configs/cnn_transformer.yaml | 90 |
1 files changed, 90 insertions, 0 deletions
diff --git a/training/configs/cnn_transformer.yaml b/training/configs/cnn_transformer.yaml new file mode 100644 index 0000000..a4f16df --- /dev/null +++ b/training/configs/cnn_transformer.yaml @@ -0,0 +1,90 @@ +seed: 4711 + +network: + desc: Configuration of the PyTorch neural network. + type: CNNTransformer + args: + encoder: + type: EfficientNet + args: null + num_decoder_layers: 4 + vocab_size: 84 + hidden_dim: 256 + num_heads: 4 + expansion_dim: 1024 + dropout_rate: 0.1 + transformer_activation: glu + +model: + desc: Configuration of the PyTorch Lightning model. + type: LitTransformerModel + args: + optimizer: + type: MADGRAD + args: + lr: 1.0e-3 + momentum: 0.9 + weight_decay: 0 + eps: 1.0e-6 + lr_scheduler: + type: OneCycleLR + args: + interval: &interval step + max_lr: 1.0e-3 + three_phase: true + epochs: 512 + steps_per_epoch: 1246 # num_samples / batch_size + criterion: + type: CrossEntropyLoss + args: + weight: null + ignore_index: -100 + reduction: mean + monitor: val_loss + mapping: sentence_piece + +data: + desc: Configuration of the training/test data. + type: IAMExtendedParagraphs + args: + batch_size: 8 + num_workers: 12 + train_fraction: 0.8 + augment: true + +callbacks: + - type: ModelCheckpoint + args: + monitor: val_loss + mode: min + save_last: true + # - type: StochasticWeightAveraging + # args: + # swa_epoch_start: 0.8 + # swa_lrs: 0.05 + # annealing_epochs: 10 + # annealing_strategy: cos + # device: null + - type: LearningRateMonitor + args: + logging_interval: *interval + # - type: EarlyStopping + # args: + # monitor: val_loss + # mode: min + # patience: 10 + +trainer: + desc: Configuration of the PyTorch Lightning Trainer. + args: + stochastic_weight_avg: false + auto_scale_batch_size: binsearch + gradient_clip_val: 0 + fast_dev_run: true + gpus: 1 + precision: 16 + max_epochs: 512 + terminate_on_nan: true + weights_summary: top + +load_checkpoint: null |