From e643e0c61ab33ce1bb8cfdebc92fc0670c82afda Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 15 Apr 2024 21:48:18 +0200 Subject: Update configs --- training/conf/network/convformer_lines.yaml | 31 +++++++++++++++ training/conf/network/convnext.yaml | 16 ++++---- training/conf/network/mammut_cvit_lines.yaml | 51 +++++++++++++++++++++++++ training/conf/network/mammut_lines.yaml | 19 ++++++---- training/conf/network/vit_lines.yaml | 56 +++++++++++----------------- 5 files changed, 124 insertions(+), 49 deletions(-) create mode 100644 training/conf/network/convformer_lines.yaml create mode 100644 training/conf/network/mammut_cvit_lines.yaml (limited to 'training/conf/network') diff --git a/training/conf/network/convformer_lines.yaml b/training/conf/network/convformer_lines.yaml new file mode 100644 index 0000000..ef9c831 --- /dev/null +++ b/training/conf/network/convformer_lines.yaml @@ -0,0 +1,31 @@ +_target_: text_recognizer.network.convformer.Convformer +image_height: 7 +image_width: 128 +patch_height: 1 +patch_width: 1 +dim: &dim 512 +num_classes: &num_classes 57 +encoder: + _target_: text_recognizer.network.convnext.convnext.ConvNext + dim: 16 + dim_mults: [2, 8, 32] + depths: [2, 2, 2] + attn: null +decoder: + _target_: text_recognizer.network.transformer.decoder.Decoder + dim: *dim + ff_mult: 4 + heads: 12 + dim_head: 64 + depth: 6 + dropout_rate: 0. + one_kv_head: true +token_embedding: + _target_: "text_recognizer.network.transformer.embedding.token.\ + TokenEmbedding" + num_tokens: *num_classes + dim: *dim + use_l2: true +tie_embeddings: false +pad_index: 3 +channels: 512 diff --git a/training/conf/network/convnext.yaml b/training/conf/network/convnext.yaml index 40343a7..bcbc78e 100644 --- a/training/conf/network/convnext.yaml +++ b/training/conf/network/convnext.yaml @@ -1,15 +1,15 @@ _target_: text_recognizer.network.convnext.convnext.ConvNext dim: 8 -dim_mults: [2, 8] -depths: [2, 2] +dim_mults: [2, 8, 8, 8] +depths: [2, 2, 2, 2] attn: _target_: text_recognizer.network.convnext.transformer.Transformer - attn: - _target_: text_recognizer.network.convnext.transformer.Attention - dim: 64 - heads: 4 - dim_head: 64 - scale: 8 + attn: null + # _target_: text_recognizer.network.convnext.transformer.Attention + # dim: 64 + # heads: 4 + # dim_head: 64 + # scale: 8 ff: _target_: text_recognizer.network.convnext.transformer.FeedForward dim: 64 diff --git a/training/conf/network/mammut_cvit_lines.yaml b/training/conf/network/mammut_cvit_lines.yaml new file mode 100644 index 0000000..75fcccb --- /dev/null +++ b/training/conf/network/mammut_cvit_lines.yaml @@ -0,0 +1,51 @@ +_target_: text_recognizer.network.mammut.MaMMUT +encoder: + _target_: text_recognizer.network.cvit.CVit + image_height: 7 + image_width: 128 + patch_height: 7 + patch_width: 1 + dim: &dim 512 + encoder: + _target_: text_recognizer.network.transformer.encoder.Encoder + dim: *dim + heads: 8 + dim_head: 64 + ff_mult: 4 + depth: 2 + dropout_rate: 0.5 + use_rotary_emb: true + one_kv_head: true + stem: + _target_: text_recognizer.network.convnext.convnext.ConvNext + dim: 16 + dim_mults: [2, 8, 32] + depths: [2, 2, 4] + attn: null + channels: 512 +image_attn_pool: + _target_: text_recognizer.network.transformer.attention.Attention + dim: *dim + heads: 8 + causal: false + dim_head: 64 + ff_mult: 4 + dropout_rate: 0.0 + use_flash: true + norm_context: true + use_rotary_emb: false + one_kv_head: true +decoder: + _target_: text_recognizer.network.transformer.decoder.Decoder + dim: *dim + ff_mult: 4 + heads: 8 + dim_head: 64 + depth: 6 + dropout_rate: 0.5 + one_kv_head: true +dim: *dim +dim_latent: *dim +num_tokens: 57 +pad_index: 3 +num_image_queries: 64 diff --git a/training/conf/network/mammut_lines.yaml b/training/conf/network/mammut_lines.yaml index f1c73d0..0b27f09 100644 --- a/training/conf/network/mammut_lines.yaml +++ b/training/conf/network/mammut_lines.yaml @@ -4,17 +4,20 @@ encoder: image_height: 56 image_width: 1024 patch_height: 56 - patch_width: 8 + patch_width: 2 dim: &dim 512 encoder: _target_: text_recognizer.network.transformer.encoder.Encoder dim: *dim - heads: 12 + heads: 16 dim_head: 64 ff_mult: 4 depth: 6 - dropout_rate: 0.1 + dropout_rate: 0. + use_rotary_emb: true + one_kv_head: true channels: 1 + patch_dropout: 0.5 image_attn_pool: _target_: text_recognizer.network.transformer.attention.Attention dim: *dim @@ -25,7 +28,8 @@ image_attn_pool: dropout_rate: 0.0 use_flash: true norm_context: true - rotary_emb: null + use_rotary_emb: false + one_kv_head: true decoder: _target_: text_recognizer.network.transformer.decoder.Decoder dim: *dim @@ -33,9 +37,10 @@ decoder: heads: 12 dim_head: 64 depth: 6 - dropout_rate: 0.1 + dropout_rate: 0. + one_kv_head: true dim: *dim dim_latent: *dim -num_tokens: 58 +num_tokens: 57 pad_index: 3 -num_image_queries: 256 +num_image_queries: 128 diff --git a/training/conf/network/vit_lines.yaml b/training/conf/network/vit_lines.yaml index 638dae1..a8045c2 100644 --- a/training/conf/network/vit_lines.yaml +++ b/training/conf/network/vit_lines.yaml @@ -1,51 +1,39 @@ -_target_: text_recognizer.network.convformer.Convformer -image_height: 7 -image_width: 128 -patch_height: 7 -patch_width: 1 +_target_: text_recognizer.network.transformer.transformer.Transformer dim: &dim 768 -num_classes: &num_classes 58 +num_classes: &num_classes 57 encoder: - _target_: text_recognizer.network.transformer.encoder.Encoder + _target_: text_recognizer.network.transformer.vit.Vit + image_height: 56 + image_width: 1024 + patch_height: 56 + patch_width: 8 dim: *dim - inner_dim: 3072 - ff_mult: 4 - heads: 12 - dim_head: 64 - depth: 4 - dropout_rate: 0.1 + encoder: + _target_: text_recognizer.network.transformer.encoder.Encoder + dim: *dim + heads: 16 + dim_head: 64 + ff_mult: 4 + depth: 6 + dropout_rate: 0. + use_rotary_emb: true + one_kv_head: false + channels: 1 + patch_dropout: 0.4 decoder: _target_: text_recognizer.network.transformer.decoder.Decoder dim: *dim - inner_dim: 3072 ff_mult: 4 heads: 12 dim_head: 64 depth: 6 - dropout_rate: 0.1 + dropout_rate: 0. + one_kv_head: false token_embedding: _target_: "text_recognizer.network.transformer.embedding.token.\ TokenEmbedding" num_tokens: *num_classes dim: *dim use_l2: true -tie_embeddings: true +tie_embeddings: false pad_index: 3 -channels: 64 -stem: - _target_: text_recognizer.network.convnext.convnext.ConvNext - dim: 8 - dim_mults: [2, 8, 8] - depths: [2, 2, 2] - attn: null - # _target_: text_recognizer.network.convnext.transformer.Transformer - # attn: - # _target_: text_recognizer.network.convnext.transformer.Attention - # dim: 64 - # heads: 4 - # dim_head: 64 - # scale: 8 - # ff: - # _target_: text_recognizer.network.convnext.transformer.FeedForward - # dim: 64 - # mult: 4 -- cgit v1.2.3-70-g09d2