From e643e0c61ab33ce1bb8cfdebc92fc0670c82afda Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Mon, 15 Apr 2024 21:48:18 +0200
Subject: Update configs

---
 training/conf/network/convformer_lines.yaml  | 31 +++++++++++++++
 training/conf/network/convnext.yaml          | 16 ++++----
 training/conf/network/mammut_cvit_lines.yaml | 51 +++++++++++++++++++++++++
 training/conf/network/mammut_lines.yaml      | 19 ++++++----
 training/conf/network/vit_lines.yaml         | 56 +++++++++++-----------------
 5 files changed, 124 insertions(+), 49 deletions(-)
 create mode 100644 training/conf/network/convformer_lines.yaml
 create mode 100644 training/conf/network/mammut_cvit_lines.yaml

(limited to 'training/conf/network')

diff --git a/training/conf/network/convformer_lines.yaml b/training/conf/network/convformer_lines.yaml
new file mode 100644
index 0000000..ef9c831
--- /dev/null
+++ b/training/conf/network/convformer_lines.yaml
@@ -0,0 +1,31 @@
+_target_: text_recognizer.network.convformer.Convformer
+image_height: 7
+image_width: 128
+patch_height: 1
+patch_width: 1
+dim: &dim 512
+num_classes: &num_classes 57
+encoder:
+  _target_: text_recognizer.network.convnext.convnext.ConvNext
+  dim: 16
+  dim_mults: [2, 8, 32]
+  depths: [2, 2, 2]
+  attn: null
+decoder:
+  _target_: text_recognizer.network.transformer.decoder.Decoder
+  dim: *dim
+  ff_mult: 4
+  heads: 12
+  dim_head: 64
+  depth: 6
+  dropout_rate: 0.
+  one_kv_head: true
+token_embedding:
+  _target_: "text_recognizer.network.transformer.embedding.token.\
+    TokenEmbedding"
+  num_tokens: *num_classes
+  dim: *dim
+  use_l2: true
+tie_embeddings: false
+pad_index: 3
+channels: 512
diff --git a/training/conf/network/convnext.yaml b/training/conf/network/convnext.yaml
index 40343a7..bcbc78e 100644
--- a/training/conf/network/convnext.yaml
+++ b/training/conf/network/convnext.yaml
@@ -1,15 +1,15 @@
 _target_: text_recognizer.network.convnext.convnext.ConvNext
 dim: 8
-dim_mults: [2, 8]
-depths: [2, 2]
+dim_mults: [2, 8, 8, 8]
+depths: [2, 2, 2, 2]
 attn:
   _target_: text_recognizer.network.convnext.transformer.Transformer
-  attn:
-    _target_: text_recognizer.network.convnext.transformer.Attention
-    dim: 64
-    heads: 4
-    dim_head: 64
-    scale: 8
+  attn: null
+    # _target_: text_recognizer.network.convnext.transformer.Attention
+    # dim: 64
+    # heads: 4
+    # dim_head: 64
+    # scale: 8
   ff:
     _target_: text_recognizer.network.convnext.transformer.FeedForward
     dim: 64
diff --git a/training/conf/network/mammut_cvit_lines.yaml b/training/conf/network/mammut_cvit_lines.yaml
new file mode 100644
index 0000000..75fcccb
--- /dev/null
+++ b/training/conf/network/mammut_cvit_lines.yaml
@@ -0,0 +1,51 @@
+_target_: text_recognizer.network.mammut.MaMMUT
+encoder:
+  _target_: text_recognizer.network.cvit.CVit
+  image_height: 7
+  image_width: 128
+  patch_height: 7
+  patch_width: 1
+  dim: &dim 512
+  encoder:
+    _target_: text_recognizer.network.transformer.encoder.Encoder
+    dim: *dim
+    heads: 8
+    dim_head: 64
+    ff_mult: 4
+    depth: 2
+    dropout_rate: 0.5
+    use_rotary_emb: true
+    one_kv_head: true
+  stem:
+    _target_: text_recognizer.network.convnext.convnext.ConvNext
+    dim: 16
+    dim_mults: [2, 8, 32]
+    depths: [2, 2, 4]
+    attn: null
+  channels: 512
+image_attn_pool:
+  _target_: text_recognizer.network.transformer.attention.Attention
+  dim: *dim
+  heads: 8
+  causal: false
+  dim_head: 64
+  ff_mult: 4
+  dropout_rate: 0.0
+  use_flash: true
+  norm_context: true
+  use_rotary_emb: false
+  one_kv_head: true
+decoder:
+  _target_: text_recognizer.network.transformer.decoder.Decoder
+  dim: *dim
+  ff_mult: 4
+  heads: 8
+  dim_head: 64
+  depth: 6
+  dropout_rate: 0.5
+  one_kv_head: true
+dim: *dim
+dim_latent: *dim
+num_tokens: 57
+pad_index: 3
+num_image_queries: 64
diff --git a/training/conf/network/mammut_lines.yaml b/training/conf/network/mammut_lines.yaml
index f1c73d0..0b27f09 100644
--- a/training/conf/network/mammut_lines.yaml
+++ b/training/conf/network/mammut_lines.yaml
@@ -4,17 +4,20 @@ encoder:
   image_height: 56
   image_width: 1024
   patch_height: 56
-  patch_width: 8
+  patch_width: 2
   dim: &dim 512
   encoder:
     _target_: text_recognizer.network.transformer.encoder.Encoder
     dim: *dim
-    heads: 12
+    heads: 16
     dim_head: 64
     ff_mult: 4
     depth: 6
-    dropout_rate: 0.1
+    dropout_rate: 0.
+    use_rotary_emb: true
+    one_kv_head: true
   channels: 1
+  patch_dropout: 0.5
 image_attn_pool:
   _target_: text_recognizer.network.transformer.attention.Attention
   dim: *dim
@@ -25,7 +28,8 @@ image_attn_pool:
   dropout_rate: 0.0
   use_flash: true
   norm_context: true
-  rotary_emb: null
+  use_rotary_emb: false
+  one_kv_head: true
 decoder:
   _target_: text_recognizer.network.transformer.decoder.Decoder
   dim: *dim
@@ -33,9 +37,10 @@ decoder:
   heads: 12
   dim_head: 64
   depth: 6
-  dropout_rate: 0.1
+  dropout_rate: 0.
+  one_kv_head: true
 dim: *dim
 dim_latent: *dim
-num_tokens: 58
+num_tokens: 57
 pad_index: 3
-num_image_queries: 256
+num_image_queries: 128
diff --git a/training/conf/network/vit_lines.yaml b/training/conf/network/vit_lines.yaml
index 638dae1..a8045c2 100644
--- a/training/conf/network/vit_lines.yaml
+++ b/training/conf/network/vit_lines.yaml
@@ -1,51 +1,39 @@
-_target_: text_recognizer.network.convformer.Convformer
-image_height: 7
-image_width: 128
-patch_height: 7
-patch_width: 1
+_target_: text_recognizer.network.transformer.transformer.Transformer
 dim: &dim 768
-num_classes: &num_classes 58
+num_classes: &num_classes 57
 encoder:
-  _target_: text_recognizer.network.transformer.encoder.Encoder
+  _target_: text_recognizer.network.transformer.vit.Vit
+  image_height: 56
+  image_width: 1024
+  patch_height: 56
+  patch_width: 8
   dim: *dim
-  inner_dim: 3072
-  ff_mult: 4
-  heads: 12
-  dim_head: 64
-  depth: 4
-  dropout_rate: 0.1
+  encoder:
+    _target_: text_recognizer.network.transformer.encoder.Encoder
+    dim: *dim
+    heads: 16
+    dim_head: 64
+    ff_mult: 4
+    depth: 6
+    dropout_rate: 0.
+    use_rotary_emb: true
+    one_kv_head: false
+  channels: 1
+  patch_dropout: 0.4
 decoder:
   _target_: text_recognizer.network.transformer.decoder.Decoder
   dim: *dim
-  inner_dim: 3072
   ff_mult: 4
   heads: 12
   dim_head: 64
   depth: 6
-  dropout_rate: 0.1
+  dropout_rate: 0.
+  one_kv_head: false
 token_embedding:
   _target_: "text_recognizer.network.transformer.embedding.token.\
     TokenEmbedding"
   num_tokens: *num_classes
   dim: *dim
   use_l2: true
-tie_embeddings: true
+tie_embeddings: false
 pad_index: 3
-channels: 64
-stem:
-  _target_: text_recognizer.network.convnext.convnext.ConvNext
-  dim: 8
-  dim_mults: [2, 8, 8]
-  depths: [2, 2, 2]
-  attn: null
-    # _target_: text_recognizer.network.convnext.transformer.Transformer
-    # attn:
-    #   _target_: text_recognizer.network.convnext.transformer.Attention
-    #   dim: 64
-    #   heads: 4
-    #   dim_head: 64
-    #   scale: 8
-    # ff:
-    #   _target_: text_recognizer.network.convnext.transformer.FeedForward
-    #   dim: 64
-    #   mult: 4
-- 
cgit v1.2.3-70-g09d2