From 3e24b92ee1bac124ea8c7bddb15236ccc5fe300d Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 27 Oct 2021 22:16:39 +0200 Subject: Update to configs --- .../conf/experiment/barlow_twins_paragraphs.yaml | 4 +- training/conf/experiment/cnn_htr_wp_lines.yaml | 2 +- .../experiment/conv_transformer_paragraphs.yaml | 46 ++++++++++++---------- .../experiment/conv_transformer_paragraphs_wp.yaml | 6 +-- training/conf/mapping/characters.yaml | 2 +- training/conf/mapping/word_piece.yaml | 2 +- 6 files changed, 33 insertions(+), 29 deletions(-) (limited to 'training') diff --git a/training/conf/experiment/barlow_twins_paragraphs.yaml b/training/conf/experiment/barlow_twins_paragraphs.yaml index caefb47..9552c0b 100644 --- a/training/conf/experiment/barlow_twins_paragraphs.yaml +++ b/training/conf/experiment/barlow_twins_paragraphs.yaml @@ -37,10 +37,10 @@ optimizers: lr_schedulers: network: _target_: torch.optim.lr_scheduler.OneCycleLR - max_lr: 1.0e-1 + max_lr: 1.0e-3 total_steps: null epochs: *epochs - steps_per_epoch: 5053 + steps_per_epoch: 40 pct_start: 0.03 anneal_strategy: cos cycle_momentum: true diff --git a/training/conf/experiment/cnn_htr_wp_lines.yaml b/training/conf/experiment/cnn_htr_wp_lines.yaml index 6cdd023..9f1164a 100644 --- a/training/conf/experiment/cnn_htr_wp_lines.yaml +++ b/training/conf/experiment/cnn_htr_wp_lines.yaml @@ -117,7 +117,7 @@ network: cross_attend: true pre_norm: true rotary_emb: - _target_: text_recognizer.networks.transformer.positional_encodings.rotary_embedding.RotaryEmbedding + _target_: text_recognizer.networks.transformer.positional_encodings.rotary.RotaryEmbedding dim: 32 pixel_pos_embedding: _target_: text_recognizer.networks.transformer.positional_encodings.PositionalEncoding2D diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml index 9e8bc50..4de9722 100644 --- a/training/conf/experiment/conv_transformer_paragraphs.yaml +++ b/training/conf/experiment/conv_transformer_paragraphs.yaml @@ -1,4 +1,4 @@ -# @package _glbal_ +# @package _global_ defaults: - override /mapping: null @@ -11,7 +11,7 @@ defaults: - override /optimizers: null -epochs: &epochs 1000 +epochs: &epochs 720 ignore_index: &ignore_index 3 num_classes: &num_classes 58 max_output_len: &max_output_len 682 @@ -23,7 +23,7 @@ criterion: mapping: &mapping mapping: - _target_: text_recognizer.data.mappings.emnist_mapping.EmnistMapping + _target_: text_recognizer.data.mappings.emnist.EmnistMapping extra_symbols: [ "\n" ] callbacks: @@ -38,11 +38,10 @@ callbacks: optimizers: madgrad: _target_: madgrad.MADGRAD - lr: 2.0e-4 + lr: 1.0e-4 momentum: 0.9 weight_decay: 5.0e-6 eps: 1.0e-6 - parameters: network lr_schedulers: @@ -51,18 +50,17 @@ lr_schedulers: max_lr: 1.0e-4 total_steps: null epochs: *epochs - steps_per_epoch: 632 + steps_per_epoch: 316 pct_start: 0.03 anneal_strategy: cos cycle_momentum: true base_momentum: 0.85 max_momentum: 0.95 div_factor: 25 - final_div_factor: 1.0e4 + final_div_factor: 1.0e2 three_phase: false last_epoch: -1 verbose: false - # Non-class arguments interval: step monitor: val/loss @@ -79,7 +77,7 @@ network: input_dims: [1, 576, 640] hidden_dim: &hidden_dim 256 encoder_dim: 1280 - dropout_rate: 0.2 + dropout_rate: 0.1 num_classes: *num_classes pad_index: *ignore_index encoder: @@ -90,35 +88,41 @@ network: bn_momentum: 0.99 bn_eps: 1.0e-3 decoder: - _target_: text_recognizer.networks.transformer.Decoder + _target_: text_recognizer.networks.transformer.layers.Decoder dim: *hidden_dim - depth: 3 + depth: 3 num_heads: 4 attn_fn: text_recognizer.networks.transformer.attention.Attention attn_kwargs: dim_head: 32 - dropout_rate: 0.2 + dropout_rate: 0.05 + local_attn_fn: text_recognizer.networks.transformer.local_attention.LocalAttention + local_attn_kwargs: + dim_head: 32 + dropout_rate: 0.05 + window_size: 11 + look_back: 2 + depth: 2 norm_fn: text_recognizer.networks.transformer.norm.ScaleNorm ff_fn: text_recognizer.networks.transformer.mlp.FeedForward ff_kwargs: dim_out: null expansion_factor: 4 glu: true - dropout_rate: 0.2 + dropout_rate: 0.05 cross_attend: true pre_norm: true rotary_emb: - _target_: text_recognizer.networks.transformer.positional_encodings.rotary_embedding.RotaryEmbedding + _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding dim: 32 pixel_pos_embedding: - _target_: text_recognizer.networks.transformer.positional_encodings.PositionalEncoding2D - hidden_dim: *hidden_dim - max_h: 18 - max_w: 20 + _target_: text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbedding + dim: *hidden_dim + shape: [18, 20] token_pos_embedding: - _target_: text_recognizer.networks.transformer.positional_encodings.PositionalEncoding + _target_: text_recognizer.networks.transformer.embeddings.fourier.PositionalEncoding hidden_dim: *hidden_dim - dropout_rate: 0.2 + dropout_rate: 0.05 max_len: *max_output_len model: @@ -134,7 +138,7 @@ trainer: stochastic_weight_avg: true auto_scale_batch_size: binsearch auto_lr_find: false - gradient_clip_val: 0.75 + gradient_clip_val: 0.5 fast_dev_run: false gpus: 1 precision: 16 diff --git a/training/conf/experiment/conv_transformer_paragraphs_wp.yaml b/training/conf/experiment/conv_transformer_paragraphs_wp.yaml index ebaa17a..91fba9a 100644 --- a/training/conf/experiment/conv_transformer_paragraphs_wp.yaml +++ b/training/conf/experiment/conv_transformer_paragraphs_wp.yaml @@ -103,14 +103,14 @@ network: attn_fn: text_recognizer.networks.transformer.attention.Attention attn_kwargs: dim_head: 32 - dropout_rate: 0.2 + dropout_rate: 0.05 norm_fn: text_recognizer.networks.transformer.norm.ScaleNorm ff_fn: text_recognizer.networks.transformer.mlp.FeedForward ff_kwargs: dim_out: null expansion_factor: 4 glu: true - dropout_rate: 0.2 + dropout_rate: 0.05 cross_attend: true pre_norm: true rotary_emb: @@ -124,7 +124,7 @@ network: token_pos_embedding: _target_: text_recognizer.networks.transformer.positional_encodings.PositionalEncoding hidden_dim: *hidden_dim - dropout_rate: 0.2 + dropout_rate: 0.05 max_len: *max_output_len model: diff --git a/training/conf/mapping/characters.yaml b/training/conf/mapping/characters.yaml index 41a26a3..d91c9e5 100644 --- a/training/conf/mapping/characters.yaml +++ b/training/conf/mapping/characters.yaml @@ -1,2 +1,2 @@ -_target_: text_recognizer.data.mappings.emnist_mapping.EmnistMapping +_target_: text_recognizer.data.mappings.emnist.EmnistMapping extra_symbols: [ "\n" ] diff --git a/training/conf/mapping/word_piece.yaml b/training/conf/mapping/word_piece.yaml index c005cc4..6b4dc07 100644 --- a/training/conf/mapping/word_piece.yaml +++ b/training/conf/mapping/word_piece.yaml @@ -1,4 +1,4 @@ -_target_: text_recognizer.data.mappings.word_piece_mapping.WordPieceMapping +_target_: text_recognizer.data.mappings.word_piece.WordPieceMapping num_features: 1000 tokens: iamdb_1kwp_tokens_1000.txt lexicon: iamdb_1kwp_lex_1000.txt -- cgit v1.2.3-70-g09d2