Diffstat (limited to 'training/conf/experiment')
-rw-r--r--   training/conf/experiment/conv_transformer_paragraphs.yaml   37
1 file changed, 22 insertions, 15 deletions
diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml
index 32f5763..859117f 100644
--- a/training/conf/experiment/conv_transformer_paragraphs.yaml
+++ b/training/conf/experiment/conv_transformer_paragraphs.yaml
@@ -10,7 +10,7 @@ defaults:
   - override /lr_schedulers: null
   - override /optimizers: null
 
-epochs: &epochs 512
+epochs: &epochs 600
 ignore_index: &ignore_index 3
 num_classes: &num_classes 58
 max_output_len: &max_output_len 682
@@ -29,7 +29,7 @@ callbacks:
   stochastic_weight_averaging:
     _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
     swa_epoch_start: 0.75
-    swa_lrs: 3.0e-5
+    swa_lrs: 1.0e-5
     annealing_epochs: 10
     annealing_strategy: cos
     device: null
@@ -37,24 +37,30 @@ callbacks:
 optimizers:
   madgrad:
     _target_: madgrad.MADGRAD
-    lr: 3.0e-4
+    lr: 1.5e-4
     momentum: 0.9
-    weight_decay: 5.0e-6
+    weight_decay: 0.0
     eps: 1.0e-6
     parameters: network
 
 lr_schedulers:
   network:
-    _target_: torch.optim.lr_scheduler.CosineAnnealingLR
-    T_max: *epochs
-    eta_min: 1.0e-5
-    last_epoch: -1
+    _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+    mode: min
+    factor: 0.1
+    patience: 10
+    threshold: 1.0e-4
+    threshold_mode: rel
+    cooldown: 0
+    min_lr: 1.0e-5
+    eps: 1.0e-8
+    verbose: false
     interval: epoch
     monitor: val/loss
 
 datamodule:
   _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs
-  batch_size: 4
+  batch_size: 6
   num_workers: 12
   train_fraction: 0.8
   pin_memory: true
@@ -66,10 +72,10 @@ rotary_embedding: &rotary_embedding
   dim: 64
 
 attn: &attn
-  dim: &hidden_dim 192
+  dim: &hidden_dim 256
   num_heads: 4
   dim_head: 64
-  dropout_rate: &dropout_rate 0.5
+  dropout_rate: &dropout_rate 0.25
 
 network:
   _target_: text_recognizer.networks.conv_transformer.ConvTransformer
@@ -78,11 +84,12 @@ network:
   num_classes: *num_classes
   pad_index: *ignore_index
   encoder:
-    _target_: text_recognizer.networks.efficientnet.EfficientNet
-    arch: b1
+    _target_: text_recognizer.networks.efficientnet.efficientnet.EfficientNet
+    arch: b0
     stochastic_dropout_rate: 0.2
     bn_momentum: 0.99
     bn_eps: 1.0e-3
+    depth: 7
   decoder:
     depth: 6
     _target_: text_recognizer.networks.transformer.layers.Decoder
@@ -109,13 +116,13 @@ network:
   pixel_pos_embedding:
     _target_: text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbedding
     dim: *hidden_dim
-    shape: &shape [36, 40]
+    shape: &shape [18, 20]
   axial_encoder:
     _target_: text_recognizer.networks.transformer.axial_attention.encoder.AxialEncoder
     dim: *hidden_dim
     heads: 4
     shape: *shape
-    depth: 1
+    depth: 2
     dim_head: 64
     dim_index: 1
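The most substantive change is the learning-rate schedule: CosineAnnealingLR is replaced by ReduceLROnPlateau driven by the monitored validation loss, alongside a lower MADGRAD learning rate and disabled weight decay. As a rough illustration only (not code from this repository), the sketch below shows the optimizer/scheduler pair that these new values would instantiate in plain PyTorch plus the madgrad package; the `model` object is a hypothetical stand-in for the ConvTransformer network the config actually builds.

    import torch
    from madgrad import MADGRAD  # optimizer referenced by _target_: madgrad.MADGRAD

    # Hypothetical stand-in for the instantiated ConvTransformer network.
    model = torch.nn.Linear(256, 58)

    # Optimizer values from the updated config: lr lowered to 1.5e-4, weight decay disabled.
    optimizer = MADGRAD(
        model.parameters(),
        lr=1.5e-4,
        momentum=0.9,
        weight_decay=0.0,
        eps=1.0e-6,
    )

    # New scheduler: ReduceLROnPlateau instead of CosineAnnealingLR.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=0.1,
        patience=10,
        threshold=1.0e-4,
        threshold_mode="rel",
        cooldown=0,
        min_lr=1.0e-5,
        eps=1.0e-8,
    )

    # Stepped once per epoch with the monitored validation loss.
    val_loss = 0.123  # placeholder value
    scheduler.step(val_loss)

Note that `interval: epoch` and `monitor: val/loss` are not arguments of the scheduler itself; they are presumably read by the training loop (e.g. a PyTorch Lightning lr_scheduler dict) to decide when and with which metric to call `scheduler.step()`.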