From e643e0c61ab33ce1bb8cfdebc92fc0670c82afda Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 15 Apr 2024 21:48:18 +0200 Subject: Update configs --- training/conf/experiment/convformer_lines.yaml | 58 ++++++++++++++++++++++++++ training/conf/experiment/mammut_lines.yaml | 5 ++- training/conf/experiment/vit_lines.yaml | 3 ++ training/conf/network/convformer_lines.yaml | 31 ++++++++++++++ training/conf/network/convnext.yaml | 16 +++---- training/conf/network/mammut_cvit_lines.yaml | 51 ++++++++++++++++++++++ training/conf/network/mammut_lines.yaml | 19 +++++---- training/conf/network/vit_lines.yaml | 56 ++++++++++--------------- 8 files changed, 189 insertions(+), 50 deletions(-) create mode 100644 training/conf/experiment/convformer_lines.yaml create mode 100644 training/conf/network/convformer_lines.yaml create mode 100644 training/conf/network/mammut_cvit_lines.yaml diff --git a/training/conf/experiment/convformer_lines.yaml b/training/conf/experiment/convformer_lines.yaml new file mode 100644 index 0000000..f573433 --- /dev/null +++ b/training/conf/experiment/convformer_lines.yaml @@ -0,0 +1,58 @@ +# @package _global_ + +defaults: + - override /criterion: cross_entropy + - override /callbacks: htr + - override /datamodule: iam_lines + - override /network: convformer_lines + - override /model: lit_transformer + - override /lr_scheduler: cosine_annealing + - override /optimizer: adan + +tags: [lines, vit] +epochs: &epochs 320 +ignore_index: &ignore_index 3 +# summary: [[1, 1, 56, 1024], [1, 89]] + +logger: + wandb: + tags: ${tags} + +criterion: + ignore_index: *ignore_index + # label_smoothing: 0.05 + + +decoder: + max_output_len: 89 + +optimizer: + lr: 1.0e-3 + +# callbacks: +# stochastic_weight_averaging: +# _target_: pytorch_lightning.callbacks.StochasticWeightAveraging +# swa_epoch_start: 0.75 +# swa_lrs: 1.0e-5 +# annealing_epochs: 10 +# annealing_strategy: cos +# device: null + +lr_scheduler: + T_max: *epochs + +datamodule: + batch_size: 8 + train_fraction: 0.95 + +model: + max_output_len: 89 + +trainer: + fast_dev_run: false + gradient_clip_val: 1.0 + max_epochs: *epochs + accumulate_grad_batches: 1 + limit_train_batches: 1.0 + limit_val_batches: 1.0 + limit_test_batches: 1.0 diff --git a/training/conf/experiment/mammut_lines.yaml b/training/conf/experiment/mammut_lines.yaml index e74e219..eb6f765 100644 --- a/training/conf/experiment/mammut_lines.yaml +++ b/training/conf/experiment/mammut_lines.yaml @@ -39,12 +39,15 @@ lr_scheduler: T_max: *epochs datamodule: - batch_size: 8 + batch_size: 16 train_fraction: 0.95 model: max_output_len: 89 +optimizer: + lr: 1.0e-3 + trainer: fast_dev_run: false gradient_clip_val: 1.0 diff --git a/training/conf/experiment/vit_lines.yaml b/training/conf/experiment/vit_lines.yaml index 08ed481..2f7731e 100644 --- a/training/conf/experiment/vit_lines.yaml +++ b/training/conf/experiment/vit_lines.yaml @@ -26,6 +26,9 @@ criterion: decoder: max_output_len: 89 +optim4izer: + lr: 1.0e-3 + # callbacks: # stochastic_weight_averaging: # _target_: pytorch_lightning.callbacks.StochasticWeightAveraging diff --git a/training/conf/network/convformer_lines.yaml b/training/conf/network/convformer_lines.yaml new file mode 100644 index 0000000..ef9c831 --- /dev/null +++ b/training/conf/network/convformer_lines.yaml @@ -0,0 +1,31 @@ +_target_: text_recognizer.network.convformer.Convformer +image_height: 7 +image_width: 128 +patch_height: 1 +patch_width: 1 +dim: &dim 512 +num_classes: &num_classes 57 +encoder: + _target_: text_recognizer.network.convnext.convnext.ConvNext + dim: 16 + dim_mults: [2, 8, 32] + depths: [2, 2, 2] + attn: null +decoder: + _target_: text_recognizer.network.transformer.decoder.Decoder + dim: *dim + ff_mult: 4 + heads: 12 + dim_head: 64 + depth: 6 + dropout_rate: 0. + one_kv_head: true +token_embedding: + _target_: "text_recognizer.network.transformer.embedding.token.\ + TokenEmbedding" + num_tokens: *num_classes + dim: *dim + use_l2: true +tie_embeddings: false +pad_index: 3 +channels: 512 diff --git a/training/conf/network/convnext.yaml b/training/conf/network/convnext.yaml index 40343a7..bcbc78e 100644 --- a/training/conf/network/convnext.yaml +++ b/training/conf/network/convnext.yaml @@ -1,15 +1,15 @@ _target_: text_recognizer.network.convnext.convnext.ConvNext dim: 8 -dim_mults: [2, 8] -depths: [2, 2] +dim_mults: [2, 8, 8, 8] +depths: [2, 2, 2, 2] attn: _target_: text_recognizer.network.convnext.transformer.Transformer - attn: - _target_: text_recognizer.network.convnext.transformer.Attention - dim: 64 - heads: 4 - dim_head: 64 - scale: 8 + attn: null + # _target_: text_recognizer.network.convnext.transformer.Attention + # dim: 64 + # heads: 4 + # dim_head: 64 + # scale: 8 ff: _target_: text_recognizer.network.convnext.transformer.FeedForward dim: 64 diff --git a/training/conf/network/mammut_cvit_lines.yaml b/training/conf/network/mammut_cvit_lines.yaml new file mode 100644 index 0000000..75fcccb --- /dev/null +++ b/training/conf/network/mammut_cvit_lines.yaml @@ -0,0 +1,51 @@ +_target_: text_recognizer.network.mammut.MaMMUT +encoder: + _target_: text_recognizer.network.cvit.CVit + image_height: 7 + image_width: 128 + patch_height: 7 + patch_width: 1 + dim: &dim 512 + encoder: + _target_: text_recognizer.network.transformer.encoder.Encoder + dim: *dim + heads: 8 + dim_head: 64 + ff_mult: 4 + depth: 2 + dropout_rate: 0.5 + use_rotary_emb: true + one_kv_head: true + stem: + _target_: text_recognizer.network.convnext.convnext.ConvNext + dim: 16 + dim_mults: [2, 8, 32] + depths: [2, 2, 4] + attn: null + channels: 512 +image_attn_pool: + _target_: text_recognizer.network.transformer.attention.Attention + dim: *dim + heads: 8 + causal: false + dim_head: 64 + ff_mult: 4 + dropout_rate: 0.0 + use_flash: true + norm_context: true + use_rotary_emb: false + one_kv_head: true +decoder: + _target_: text_recognizer.network.transformer.decoder.Decoder + dim: *dim + ff_mult: 4 + heads: 8 + dim_head: 64 + depth: 6 + dropout_rate: 0.5 + one_kv_head: true +dim: *dim +dim_latent: *dim +num_tokens: 57 +pad_index: 3 +num_image_queries: 64 diff --git a/training/conf/network/mammut_lines.yaml b/training/conf/network/mammut_lines.yaml index f1c73d0..0b27f09 100644 --- a/training/conf/network/mammut_lines.yaml +++ b/training/conf/network/mammut_lines.yaml @@ -4,17 +4,20 @@ encoder: image_height: 56 image_width: 1024 patch_height: 56 - patch_width: 8 + patch_width: 2 dim: &dim 512 encoder: _target_: text_recognizer.network.transformer.encoder.Encoder dim: *dim - heads: 12 + heads: 16 dim_head: 64 ff_mult: 4 depth: 6 - dropout_rate: 0.1 + dropout_rate: 0. + use_rotary_emb: true + one_kv_head: true channels: 1 + patch_dropout: 0.5 image_attn_pool: _target_: text_recognizer.network.transformer.attention.Attention dim: *dim @@ -25,7 +28,8 @@ image_attn_pool: dropout_rate: 0.0 use_flash: true norm_context: true - rotary_emb: null + use_rotary_emb: false + one_kv_head: true decoder: _target_: text_recognizer.network.transformer.decoder.Decoder dim: *dim @@ -33,9 +37,10 @@ decoder: heads: 12 dim_head: 64 depth: 6 - dropout_rate: 0.1 + dropout_rate: 0. + one_kv_head: true dim: *dim dim_latent: *dim -num_tokens: 58 +num_tokens: 57 pad_index: 3 -num_image_queries: 256 +num_image_queries: 128 diff --git a/training/conf/network/vit_lines.yaml b/training/conf/network/vit_lines.yaml index 638dae1..a8045c2 100644 --- a/training/conf/network/vit_lines.yaml +++ b/training/conf/network/vit_lines.yaml @@ -1,51 +1,39 @@ -_target_: text_recognizer.network.convformer.Convformer -image_height: 7 -image_width: 128 -patch_height: 7 -patch_width: 1 +_target_: text_recognizer.network.transformer.transformer.Transformer dim: &dim 768 -num_classes: &num_classes 58 +num_classes: &num_classes 57 encoder: - _target_: text_recognizer.network.transformer.encoder.Encoder + _target_: text_recognizer.network.transformer.vit.Vit + image_height: 56 + image_width: 1024 + patch_height: 56 + patch_width: 8 dim: *dim - inner_dim: 3072 - ff_mult: 4 - heads: 12 - dim_head: 64 - depth: 4 - dropout_rate: 0.1 + encoder: + _target_: text_recognizer.network.transformer.encoder.Encoder + dim: *dim + heads: 16 + dim_head: 64 + ff_mult: 4 + depth: 6 + dropout_rate: 0. + use_rotary_emb: true + one_kv_head: false + channels: 1 + patch_dropout: 0.4 decoder: _target_: text_recognizer.network.transformer.decoder.Decoder dim: *dim - inner_dim: 3072 ff_mult: 4 heads: 12 dim_head: 64 depth: 6 - dropout_rate: 0.1 + dropout_rate: 0. + one_kv_head: false token_embedding: _target_: "text_recognizer.network.transformer.embedding.token.\ TokenEmbedding" num_tokens: *num_classes dim: *dim use_l2: true -tie_embeddings: true +tie_embeddings: false pad_index: 3 -channels: 64 -stem: - _target_: text_recognizer.network.convnext.convnext.ConvNext - dim: 8 - dim_mults: [2, 8, 8] - depths: [2, 2, 2] - attn: null - # _target_: text_recognizer.network.convnext.transformer.Transformer - # attn: - # _target_: text_recognizer.network.convnext.transformer.Attention - # dim: 64 - # heads: 4 - # dim_head: 64 - # scale: 8 - # ff: - # _target_: text_recognizer.network.convnext.transformer.FeedForward - # dim: 64 - # mult: 4 -- cgit v1.2.3-70-g09d2