Diffstat (limited to 'training/conf/experiment/conv_transformer_paragraphs.yaml')
-rw-r--r-- | training/conf/experiment/conv_transformer_paragraphs.yaml | 51
1 files changed, 15 insertions, 36 deletions
diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml
index 00ad389..d2916e1 100644
--- a/training/conf/experiment/conv_transformer_paragraphs.yaml
+++ b/training/conf/experiment/conv_transformer_paragraphs.yaml
@@ -10,8 +10,7 @@ defaults:
   - override /lr_schedulers: null
   - override /optimizers: null
-
-epochs: &epochs 720
+epochs: &epochs 512
 ignore_index: &ignore_index 3
 num_classes: &num_classes 58
 max_output_len: &max_output_len 682
@@ -29,7 +28,7 @@ callbacks:
   stochastic_weight_averaging:
     _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
     swa_epoch_start: 0.75
-    swa_lrs: 1.0e-5
+    swa_lrs: 3.0e-5
     annealing_epochs: 10
     annealing_strategy: cos
     device: null
@@ -37,7 +36,7 @@ callbacks:
 optimizers:
   madgrad:
     _target_: madgrad.MADGRAD
-    lr: 1.0e-4
+    lr: 3.0e-4
     momentum: 0.9
     weight_decay: 5.0e-6
     eps: 1.0e-6
@@ -45,27 +44,16 @@ optimizers:
 lr_schedulers:
   network:
-    _target_: torch.optim.lr_scheduler.OneCycleLR
-    max_lr: 1.0e-4
-    total_steps: null
-    epochs: *epochs
-    steps_per_epoch: 1264
-    pct_start: 0.01
-    anneal_strategy: cos
-    cycle_momentum: true
-    base_momentum: 0.85
-    max_momentum: 0.95
-    div_factor: 25
-    final_div_factor: 1.0e2
-    three_phase: false
+    _target_: torch.optim.lr_scheduler.CosineAnnealingLR
+    T_max: *epochs
+    eta_min: 1.0e-5
     last_epoch: -1
-    verbose: false
-    interval: step
+    interval: epoch
     monitor: val/loss
 datamodule:
   _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs
-  batch_size: 6
+  batch_size: 4
   num_workers: 12
   train_fraction: 0.8
   pin_memory: true
@@ -77,27 +65,25 @@ rotary_embedding: &rotary_embedding
   dim: 64
 attn: &attn
-  dim: 192
+  dim: &hidden_dim 192
   num_heads: 4
   dim_head: 64
-  dropout_rate: 0.05
+  dropout_rate: &dropout_rate 0.5
 network:
   _target_: text_recognizer.networks.conv_transformer.ConvTransformer
   input_dims: [1, 576, 640]
-  hidden_dim: &hidden_dim 192
+  hidden_dim: *hidden_dim
   num_classes: *num_classes
   pad_index: *ignore_index
   encoder:
     _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet
-    arch: b0
-    out_channels: 1280
+    arch: b1
     stochastic_dropout_rate: 0.2
     bn_momentum: 0.99
     bn_eps: 1.0e-3
   decoder:
-    depth: 3
-    local_depth: 2
+    depth: 6
     _target_: text_recognizer.networks.transformer.layers.Decoder
     self_attn:
       _target_: text_recognizer.networks.transformer.attention.Attention
       << : *attn
       causal: true
     cross_attn:
@@ -108,13 +94,6 @@ network:
       _target_: text_recognizer.networks.transformer.attention.Attention
       << : *attn
       causal: false
-    local_self_attn:
-      _target_: text_recognizer.networks.transformer.local_attention.LocalAttention
-      << : *attn
-      window_size: 31
-      look_back: 1
-      autopad: true
-      << : *rotary_embedding
     norm:
       _target_: text_recognizer.networks.transformer.norm.ScaleNorm
       normalized_shape: *hidden_dim
@@ -124,7 +103,7 @@ network:
       dim_out: null
       expansion_factor: 4
       glu: true
-      dropout_rate: 0.05
+      dropout_rate: *dropout_rate
     pre_norm: true
   pixel_pos_embedding:
     _target_: text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbedding
@@ -155,5 +134,5 @@ trainer:
   limit_val_batches: 1.0
   limit_test_batches: 1.0
   resume_from_checkpoint: null
-  accumulate_grad_batches: 4
+  accumulate_grad_batches: 2
   overfit_batches: 0
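For context on the scheduler change, the snippet below is a minimal sketch of what the new lr_schedulers.network block resolves to at runtime: a torch.optim.lr_scheduler.CosineAnnealingLR stepped once per epoch (interval: epoch) over the 512-epoch budget, replacing the per-step OneCycleLR. The dummy parameter and the torch.optim.SGD optimizer are stand-ins for illustration only; the config actually instantiates madgrad.MADGRAD with the raised lr of 3.0e-4.

    import torch
    from torch.optim.lr_scheduler import CosineAnnealingLR

    # Stand-in parameter and optimizer; the config targets madgrad.MADGRAD
    # (lr=3.0e-4, momentum=0.9, weight_decay=5.0e-6, eps=1.0e-6), not SGD.
    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.SGD(params, lr=3.0e-4, momentum=0.9)

    # Mirrors the new scheduler block: T_max picks up the *epochs anchor (512),
    # and the LR decays from 3.0e-4 down to eta_min=1.0e-5 over one cosine cycle.
    scheduler = CosineAnnealingLR(optimizer, T_max=512, eta_min=1.0e-5, last_epoch=-1)

    # With interval: epoch, PyTorch Lightning calls scheduler.step() once per epoch,
    # so the schedule spans the full 512-epoch run rather than 512 optimizer steps.
    for epoch in range(512):
        # ... training steps for this epoch ...
        scheduler.step()

The monitor: val/loss key is only consulted by metric-driven schedulers such as ReduceLROnPlateau; cosine annealing ignores it.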