Diffstat (limited to 'training/configs/image_transformer.yaml')
-rw-r--r--  training/configs/image_transformer.yaml  |  28
1 file changed, 21 insertions(+), 7 deletions(-)
diff --git a/training/configs/image_transformer.yaml b/training/configs/image_transformer.yaml
index bedcbb5..88c05c2 100644
--- a/training/configs/image_transformer.yaml
+++ b/training/configs/image_transformer.yaml
@@ -1,7 +1,7 @@
 seed: 4711
 
 network:
-  desc: null
+  desc: Configuration of the PyTorch neural network.
   type: ImageTransformer
   args:
     encoder:
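
Each section of this config follows the same desc/type/args layout. A loader would typically resolve type to a class and pass args as keyword arguments; the sketch below illustrates that pattern under stated assumptions (the instantiate helper and the module path are hypothetical, not taken from this repository):

```python
# Minimal sketch of how a desc/type/args block could be instantiated.
# `instantiate` and the module path are hypothetical, not this project's code.
import importlib

import yaml


def instantiate(section: dict, module_name: str):
    """Look up `type` in `module_name` and pass `args` as keyword arguments."""
    module = importlib.import_module(module_name)
    cls = getattr(module, section["type"])
    return cls(**(section.get("args") or {}))


with open("training/configs/image_transformer.yaml") as f:
    config = yaml.safe_load(f)

network = instantiate(config["network"], "my_project.networks")  # hypothetical module
```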
@@ -15,20 +15,24 @@ network:
     transformer_activation: glu
 
 model:
-  desc: null
+  desc: Configuration of the PyTorch Lightning model.
   type: LitTransformerModel
   args:
     optimizer:
       type: MADGRAD
       args:
-        lr: 1.0e-2
+        lr: 1.0e-3
         momentum: 0.9
         weight_decay: 0
         eps: 1.0e-6
     lr_scheduler:
-      type: CosineAnnealingLR
+      type: OneCycle
       args:
-        T_max: 512
+        interval: &interval step
+        max_lr: 1.0e-3
+        three_phase: true
+        epochs: 512
+        steps_per_epoch: 1246 # num_samples / batch_size
     criterion:
       type: CrossEntropyLoss
       args:
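
The optimizer stays MADGRAD but the learning rate drops to 1.0e-3, and the scheduler switches from a per-epoch cosine schedule to a one-cycle policy stepped every batch; steps_per_epoch: 1246 at batch_size: 16 corresponds to roughly 1246 × 16 ≈ 19 900 training samples. The interval value is anchored with &interval so it can be reused later in the file via *interval. The sketch below shows how these blocks could map onto configure_optimizers; treating the config's OneCycle type as torch.optim.lr_scheduler.OneCycleLR is an assumption, while the scheduler dict with "scheduler"/"interval" keys is standard PyTorch Lightning:

```python
# Sketch of configure_optimizers consuming the optimizer/lr_scheduler blocks above.
# Mapping "OneCycle" to torch.optim.lr_scheduler.OneCycleLR is an assumption.
import pytorch_lightning as pl
import torch
from madgrad import MADGRAD  # facebookresearch/madgrad


class LitTransformerModel(pl.LightningModule):
    # ... network, training_step, etc. elided ...

    def configure_optimizers(self):
        optimizer = MADGRAD(
            self.parameters(), lr=1.0e-3, momentum=0.9, weight_decay=0, eps=1.0e-6
        )
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=1.0e-3,
            three_phase=True,
            epochs=512,
            steps_per_epoch=1246,  # num_samples / batch_size
        )
        # interval: step -> Lightning steps the scheduler after every batch,
        # which is what a one-cycle policy expects.
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
```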
@@ -39,7 +43,7 @@ model:
     mapping: sentence_piece
 
 data:
-  desc: null
+  desc: Configuration of the training/test data.
   type: IAMExtendedParagraphs
   args:
     batch_size: 16
@@ -52,6 +56,16 @@ callbacks:
     args:
       monitor: val_loss
       mode: min
+  - type: StochasticWeightAveraging
+    args:
+      swa_epoch_start: 0.8
+      swa_lrs: 0.05
+      annealing_epochs: 10
+      annealing_strategy: cos
+      device: null
+  - type: LearningRateMonitor
+    args:
+      logging_interval: *interval
   - type: EarlyStopping
     args:
       monitor: val_loss
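
The two new callbacks correspond to PyTorch Lightning's StochasticWeightAveraging and LearningRateMonitor classes, and the *interval alias resolves to the step value anchored in the scheduler block, so learning-rate logging happens at the same per-step granularity the scheduler uses. A rough in-code equivalent of the added entries:

```python
# Rough in-code equivalent of the two callbacks added above (PyTorch Lightning 1.x).
from pytorch_lightning.callbacks import LearningRateMonitor, StochasticWeightAveraging

callbacks = [
    StochasticWeightAveraging(
        swa_epoch_start=0.8,       # start averaging after 80% of max_epochs
        swa_lrs=0.05,
        annealing_epochs=10,
        annealing_strategy="cos",
        device=None,               # None: infer the device from the model
    ),
    LearningRateMonitor(logging_interval="step"),  # *interval -> "step"
]
```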
@@ -59,7 +73,7 @@ callbacks:
       patience: 10
 
 trainer:
-  desc: null
+  desc: Configuration of the PyTorch Lightning Trainer.
   args:
     stochastic_weight_avg: true
     auto_scale_batch_size: binsearch
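
The trainer flags shown in the context lines belong to the PyTorch Lightning 1.x Trainer, where stochastic_weight_avg was still a constructor argument and auto_scale_batch_size: binsearch asks trainer.tune() to search for the largest batch size that fits in memory. A minimal wiring sketch, assuming model, datamodule, and the callbacks list from the previous sketch:

```python
# Minimal sketch of building the Trainer from the trainer/callbacks sections
# (PyTorch Lightning 1.x API; model and datamodule are assumed to exist).
import pytorch_lightning as pl

trainer = pl.Trainer(
    stochastic_weight_avg=True,
    auto_scale_batch_size="binsearch",
    callbacks=callbacks,
)
trainer.tune(model, datamodule=datamodule)  # runs the batch-size binary search
trainer.fit(model, datamodule=datamodule)
```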