Diffstat (limited to 'training/conf/network')
-rw-r--r-- | training/conf/network/convnext.yaml     | 16
-rw-r--r-- | training/conf/network/mammut_lines.yaml | 41
-rw-r--r-- | training/conf/network/vit_lines.yaml    | 54
3 files changed, 91 insertions, 20 deletions
diff --git a/training/conf/network/convnext.yaml b/training/conf/network/convnext.yaml
new file mode 100644
index 0000000..40343a7
--- /dev/null
+++ b/training/conf/network/convnext.yaml
@@ -0,0 +1,16 @@
+_target_: text_recognizer.network.convnext.convnext.ConvNext
+dim: 8
+dim_mults: [2, 8]
+depths: [2, 2]
+attn:
+  _target_: text_recognizer.network.convnext.transformer.Transformer
+  attn:
+    _target_: text_recognizer.network.convnext.transformer.Attention
+    dim: 64
+    heads: 4
+    dim_head: 64
+    scale: 8
+  ff:
+    _target_: text_recognizer.network.convnext.transformer.FeedForward
+    dim: 64
+    mult: 4
diff --git a/training/conf/network/mammut_lines.yaml b/training/conf/network/mammut_lines.yaml
new file mode 100644
index 0000000..f1c73d0
--- /dev/null
+++ b/training/conf/network/mammut_lines.yaml
@@ -0,0 +1,41 @@
+_target_: text_recognizer.network.mammut.MaMMUT
+encoder:
+  _target_: text_recognizer.network.vit.Vit
+  image_height: 56
+  image_width: 1024
+  patch_height: 56
+  patch_width: 8
+  dim: &dim 512
+  encoder:
+    _target_: text_recognizer.network.transformer.encoder.Encoder
+    dim: *dim
+    heads: 12
+    dim_head: 64
+    ff_mult: 4
+    depth: 6
+    dropout_rate: 0.1
+  channels: 1
+image_attn_pool:
+  _target_: text_recognizer.network.transformer.attention.Attention
+  dim: *dim
+  heads: 8
+  causal: false
+  dim_head: 64
+  ff_mult: 4
+  dropout_rate: 0.0
+  use_flash: true
+  norm_context: true
+  rotary_emb: null
+decoder:
+  _target_: text_recognizer.network.transformer.decoder.Decoder
+  dim: *dim
+  ff_mult: 4
+  heads: 12
+  dim_head: 64
+  depth: 6
+  dropout_rate: 0.1
+dim: *dim
+dim_latent: *dim
+num_tokens: 58
+pad_index: 3
+num_image_queries: 256
diff --git a/training/conf/network/vit_lines.yaml b/training/conf/network/vit_lines.yaml
index f32cb83..638dae1 100644
--- a/training/conf/network/vit_lines.yaml
+++ b/training/conf/network/vit_lines.yaml
@@ -1,37 +1,51 @@
-_target_: text_recognizer.network.vit.VisionTransformer
-image_height: 56
-image_width: 1024
-patch_height: 28
-patch_width: 32
-dim: &dim 1024
+_target_: text_recognizer.network.convformer.Convformer
+image_height: 7
+image_width: 128
+patch_height: 7
+patch_width: 1
+dim: &dim 768
 num_classes: &num_classes 58
 encoder:
   _target_: text_recognizer.network.transformer.encoder.Encoder
   dim: *dim
-  inner_dim: 2048
-  heads: 16
+  inner_dim: 3072
+  ff_mult: 4
+  heads: 12
   dim_head: 64
-  depth: 6
-  dropout_rate: 0.0
+  depth: 4
+  dropout_rate: 0.1
 decoder:
   _target_: text_recognizer.network.transformer.decoder.Decoder
   dim: *dim
-  inner_dim: 2048
-  heads: 16
+  inner_dim: 3072
+  ff_mult: 4
+  heads: 12
   dim_head: 64
   depth: 6
-  dropout_rate: 0.0
+  dropout_rate: 0.1
 token_embedding:
   _target_: "text_recognizer.network.transformer.embedding.token.\
     TokenEmbedding"
   num_tokens: *num_classes
   dim: *dim
   use_l2: true
-pos_embedding:
-  _target_: "text_recognizer.network.transformer.embedding.absolute.\
-    AbsolutePositionalEmbedding"
-  dim: *dim
-  max_length: 89
-  use_l2: true
-tie_embeddings: false
+tie_embeddings: true
 pad_index: 3
+channels: 64
+stem:
+  _target_: text_recognizer.network.convnext.convnext.ConvNext
+  dim: 8
+  dim_mults: [2, 8, 8]
+  depths: [2, 2, 2]
+  attn: null
+  # _target_: text_recognizer.network.convnext.transformer.Transformer
+  # attn:
+  #   _target_: text_recognizer.network.convnext.transformer.Attention
+  #   dim: 64
+  #   heads: 4
+  #   dim_head: 64
+  #   scale: 8
+  # ff:
+  #   _target_: text_recognizer.network.convnext.transformer.FeedForward
+  #   dim: 64
+  #   mult: 4
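For reference, configs like these are typically turned into network objects with hydra.utils.instantiate, which (in recent Hydra versions) recursively builds every nested node carrying a _target_. A minimal sketch, assuming the standard Hydra/OmegaConf workflow and a checkout where the config path exists:

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    # Load the new MaMMUT line-recognizer config and build the network.
    # (Illustrative usage only; the training pipeline normally composes
    # this file as part of a larger Hydra config.)
    cfg = OmegaConf.load("training/conf/network/mammut_lines.yaml")
    network = instantiate(cfg)  # -> text_recognizer.network.mammut.MaMMUT
    print(type(network))

The YAML anchor &dim / alias *dim pairs are resolved at load time, so dim, dim_latent, and the encoder/decoder widths stay in sync from a single value.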