diff options
Diffstat (limited to 'training/configs')
-rw-r--r-- | training/configs/image_transformer.yaml | 28 |
1 file changed, 21 insertions, 7 deletions
diff --git a/training/configs/image_transformer.yaml b/training/configs/image_transformer.yaml
index bedcbb5..88c05c2 100644
--- a/training/configs/image_transformer.yaml
+++ b/training/configs/image_transformer.yaml
@@ -1,7 +1,7 @@
 seed: 4711
 
 network:
-  desc: null
+  desc: Configuration of the PyTorch neural network.
   type: ImageTransformer
   args:
     encoder:
@@ -15,20 +15,24 @@ network:
     transformer_activation: glu
 
 model:
-  desc: null
+  desc: Configuration of the PyTorch Lightning model.
   type: LitTransformerModel
   args:
     optimizer:
       type: MADGRAD
       args:
-        lr: 1.0e-2
+        lr: 1.0e-3
         momentum: 0.9
         weight_decay: 0
         eps: 1.0e-6
     lr_scheduler:
-      type: CosineAnnealingLR
+      type: OneCycle
       args:
-        T_max: 512
+        interval: &interval step
+        max_lr: 1.0e-3
+        three_phase: true
+        epochs: 512
+        steps_per_epoch: 1246 # num_samples / batch_size
     criterion:
       type: CrossEntropyLoss
       args:
@@ -39,7 +43,7 @@ model:
     mapping: sentence_piece
 
 data:
-  desc: null
+  desc: Configuration of the training/test data.
   type: IAMExtendedParagraphs
   args:
     batch_size: 16
@@ -52,6 +56,16 @@ callbacks:
     args:
       monitor: val_loss
       mode: min
+  - type: StochasticWeightAveraging
+    args:
+      swa_epoch_start: 0.8
+      swa_lrs: 0.05
+      annealing_epochs: 10
+      annealing_strategy: cos
+      device: null
+  - type: LearningRateMonitor
+    args:
+      logging_interval: *interval
   - type: EarlyStopping
     args:
       monitor: val_loss
@@ -59,7 +73,7 @@ callbacks:
       patience: 10
 
 trainer:
-  desc: null
+  desc: Configuration of the PyTorch Lightning Trainer.
   args:
     stochastic_weight_avg: true
     auto_scale_batch_size: binsearch