Diffstat (limited to 'training')
 training/conf/experiment/conv_transformer_paragraphs.yaml | 37 ++++++++++++++++++++++++++---------------
 1 file changed, 22 insertions(+), 15 deletions(-)
diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml
index 32f5763..859117f 100644
--- a/training/conf/experiment/conv_transformer_paragraphs.yaml
+++ b/training/conf/experiment/conv_transformer_paragraphs.yaml
@@ -10,7 +10,7 @@ defaults:
   - override /lr_schedulers: null
   - override /optimizers: null
 
-epochs: &epochs 512
+epochs: &epochs 600
 ignore_index: &ignore_index 3
 num_classes: &num_classes 58
 max_output_len: &max_output_len 682
@@ -29,7 +29,7 @@ callbacks:
   stochastic_weight_averaging:
     _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
     swa_epoch_start: 0.75
-    swa_lrs: 3.0e-5
+    swa_lrs: 1.0e-5
     annealing_epochs: 10
     annealing_strategy: cos
     device: null
@@ -37,24 +37,30 @@ callbacks:
 optimizers:
   madgrad:
     _target_: madgrad.MADGRAD
-    lr: 3.0e-4
+    lr: 1.5e-4
     momentum: 0.9
-    weight_decay: 5.0e-6
+    weight_decay: 0.0
     eps: 1.0e-6
     parameters: network
 
 lr_schedulers:
   network:
-    _target_: torch.optim.lr_scheduler.CosineAnnealingLR
-    T_max: *epochs
-    eta_min: 1.0e-5
-    last_epoch: -1
+    _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+    mode: min
+    factor: 0.1
+    patience: 10
+    threshold: 1.0e-4
+    threshold_mode: rel
+    cooldown: 0
+    min_lr: 1.0e-5
+    eps: 1.0e-8
+    verbose: false
     interval: epoch
     monitor: val/loss
 
 datamodule:
   _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs
-  batch_size: 4
+  batch_size: 6
   num_workers: 12
   train_fraction: 0.8
   pin_memory: true
@@ -66,10 +72,10 @@ rotary_embedding: &rotary_embedding
     dim: 64
 
 attn: &attn
-  dim: &hidden_dim 192
+  dim: &hidden_dim 256
   num_heads: 4
   dim_head: 64
-  dropout_rate: &dropout_rate 0.5
+  dropout_rate: &dropout_rate 0.25
 
 network:
   _target_: text_recognizer.networks.conv_transformer.ConvTransformer
@@ -78,11 +84,12 @@ network:
   num_classes: *num_classes
   pad_index: *ignore_index
   encoder:
-    _target_: text_recognizer.networks.efficientnet.EfficientNet
-    arch: b1
+    _target_: text_recognizer.networks.efficientnet.efficientnet.EfficientNet
+    arch: b0
     stochastic_dropout_rate: 0.2
     bn_momentum: 0.99
     bn_eps: 1.0e-3
+    depth: 7
   decoder:
     depth: 6
     _target_: text_recognizer.networks.transformer.layers.Decoder
@@ -109,13 +116,13 @@ network:
   pixel_pos_embedding:
     _target_: text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbedding
     dim: *hidden_dim
-    shape: &shape [36, 40]
+    shape: &shape [18, 20]
 
   axial_encoder:
     _target_: text_recognizer.networks.transformer.axial_attention.encoder.AxialEncoder
     dim: *hidden_dim
     heads: 4
     shape: *shape
-    depth: 1
+    depth: 2
     dim_head: 64
     dim_index: 1
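
For reference, a minimal sketch of what the new optimizer/scheduler block resolves to once instantiated. This is not the repo's actual LightningModule code: the configure_optimizers function and the network argument are illustrative placeholders, and the interval/monitor keys are assumed to feed PyTorch Lightning's lr_scheduler dict rather than the scheduler itself.

    # Illustrative only: shows the objects the YAML above configures.
    import torch
    from madgrad import MADGRAD


    def configure_optimizers(network: torch.nn.Module):
        # optimizers.madgrad in the config
        optimizer = MADGRAD(
            network.parameters(),
            lr=1.5e-4,
            momentum=0.9,
            weight_decay=0.0,
            eps=1.0e-6,
        )
        # lr_schedulers.network: ReduceLROnPlateau replaces CosineAnnealingLR,
        # cutting the LR by 10x after 10 epochs without val/loss improvement,
        # down to a floor of 1.0e-5.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="min",
            factor=0.1,
            patience=10,
            threshold=1.0e-4,
            threshold_mode="rel",
            cooldown=0,
            min_lr=1.0e-5,
            eps=1.0e-8,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",    # step the scheduler once per epoch
                "monitor": "val/loss",  # metric passed to scheduler.step(metric)
            },
        }

With ReduceLROnPlateau, Lightning steps the scheduler with the monitored metric at the end of each epoch, so "monitor: val/loss" has to match a metric name logged during validation.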