diff options
Diffstat (limited to 'training/conf/network/mammut_lines.yaml')
-rw-r--r-- | training/conf/network/mammut_lines.yaml | 19 |
1 files changed, 12 insertions, 7 deletions
diff --git a/training/conf/network/mammut_lines.yaml b/training/conf/network/mammut_lines.yaml index f1c73d0..0b27f09 100644 --- a/training/conf/network/mammut_lines.yaml +++ b/training/conf/network/mammut_lines.yaml @@ -4,17 +4,20 @@ encoder: image_height: 56 image_width: 1024 patch_height: 56 - patch_width: 8 + patch_width: 2 dim: &dim 512 encoder: _target_: text_recognizer.network.transformer.encoder.Encoder dim: *dim - heads: 12 + heads: 16 dim_head: 64 ff_mult: 4 depth: 6 - dropout_rate: 0.1 + dropout_rate: 0. + use_rotary_emb: true + one_kv_head: true channels: 1 + patch_dropout: 0.5 image_attn_pool: _target_: text_recognizer.network.transformer.attention.Attention dim: *dim @@ -25,7 +28,8 @@ image_attn_pool: dropout_rate: 0.0 use_flash: true norm_context: true - rotary_emb: null + use_rotary_emb: false + one_kv_head: true decoder: _target_: text_recognizer.network.transformer.decoder.Decoder dim: *dim @@ -33,9 +37,10 @@ decoder: heads: 12 dim_head: 64 depth: 6 - dropout_rate: 0.1 + dropout_rate: 0. + one_kv_head: true dim: *dim dim_latent: *dim -num_tokens: 58 +num_tokens: 57 pad_index: 3 -num_image_queries: 256 +num_image_queries: 128 |