diff options
| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-11 21:48:34 +0200 | 
|---|---|---|
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-11 21:48:34 +0200 | 
| commit | 0ab820d3595e4f17d4f1f3c310e186692f65cc67 (patch) | |
| tree | 21891ab98c10e64ef9261c69b2d494f42cda66f1 /training/configs | |
| parent | a548e421314908771ce9e413d9fa4e205943cceb (diff) | |
Working on mapping
Diffstat (limited to 'training/configs')
| -rw-r--r-- | training/configs/image_transformer.yaml | 28 | 
1 file changed, 21 insertions, 7 deletions
```diff
diff --git a/training/configs/image_transformer.yaml b/training/configs/image_transformer.yaml
index bedcbb5..88c05c2 100644
--- a/training/configs/image_transformer.yaml
+++ b/training/configs/image_transformer.yaml
@@ -1,7 +1,7 @@
 seed: 4711
 
 network:
-        desc: null
+        desc: Configuration of the PyTorch neural network.
         type: ImageTransformer
         args:
                 encoder:
@@ -15,20 +15,24 @@ network:
                 transformer_activation: glu
 
 model:
-        desc: null
+        desc: Configuration of the PyTorch Lightning model.
         type: LitTransformerModel
         args:
                 optimizer:
                         type: MADGRAD
                         args:
-                                lr: 1.0e-2
+                                lr: 1.0e-3
                                 momentum: 0.9
                                 weight_decay: 0
                                 eps: 1.0e-6
                 lr_scheduler:
-                        type: CosineAnnealingLR
+                        type: OneCycle
                         args:
-                                T_max: 512
+                                interval: &interval step
+                                max_lr: 1.0e-3
+                                three_phase: true
+                                epochs: 512
+                                steps_per_epoch: 1246 # num_samples / batch_size
                 criterion:
                         type: CrossEntropyLoss
                         args:
@@ -39,7 +43,7 @@ model:
                 mapping: sentence_piece
 
 data:
-        desc: null
+        desc: Configuration of the training/test data.
         type: IAMExtendedParagraphs
         args:
                 batch_size: 16
@@ -52,6 +56,16 @@ callbacks:
           args:
                   monitor: val_loss
                   mode: min
+        - type: StochasticWeightAveraging
+          args:
+                  swa_epoch_start: 0.8
+                  swa_lrs: 0.05
+                  annealing_epochs: 10
+                  annealing_strategy: cos
+                  device: null
+        - type: LearningRateMonitor
+          args:
+                  logging_interval: *interval
         - type: EarlyStopping
           args:
                   monitor: val_loss
@@ -59,7 +73,7 @@ callbacks:
                   patience: 10
 
 trainer:
-        desc: null
+        desc: Configuration of the PyTorch Lightning Trainer.
         args:
                 stochastic_weight_avg: true
                 auto_scale_batch_size: binsearch
```