From 56dc112cfb649217cd624b4ff305e2db83a383b7 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Mon, 11 Sep 2023 22:15:26 +0200
Subject: Update configs

---
 training/conf/network/convnext.yaml     | 16 ++++++++++
 training/conf/network/mammut_lines.yaml | 41 +++++++++++++++++++++++++
 training/conf/network/vit_lines.yaml    | 54 +++++++++++++++++++++------------
 3 files changed, 91 insertions(+), 20 deletions(-)
 create mode 100644 training/conf/network/convnext.yaml
 create mode 100644 training/conf/network/mammut_lines.yaml

(limited to 'training/conf/network')

diff --git a/training/conf/network/convnext.yaml b/training/conf/network/convnext.yaml
new file mode 100644
index 0000000..40343a7
--- /dev/null
+++ b/training/conf/network/convnext.yaml
@@ -0,0 +1,16 @@
+_target_: text_recognizer.network.convnext.convnext.ConvNext
+dim: 8
+dim_mults: [2, 8]
+depths: [2, 2]
+attn:
+  _target_: text_recognizer.network.convnext.transformer.Transformer
+  attn:
+    _target_: text_recognizer.network.convnext.transformer.Attention
+    dim: 64
+    heads: 4
+    dim_head: 64
+    scale: 8
+  ff:
+    _target_: text_recognizer.network.convnext.transformer.FeedForward
+    dim: 64
+    mult: 4
diff --git a/training/conf/network/mammut_lines.yaml b/training/conf/network/mammut_lines.yaml
new file mode 100644
index 0000000..f1c73d0
--- /dev/null
+++ b/training/conf/network/mammut_lines.yaml
@@ -0,0 +1,41 @@
+_target_: text_recognizer.network.mammut.MaMMUT
+encoder:
+  _target_: text_recognizer.network.vit.Vit
+  image_height: 56
+  image_width: 1024
+  patch_height: 56
+  patch_width: 8
+  dim: &dim 512
+  encoder:
+    _target_: text_recognizer.network.transformer.encoder.Encoder
+    dim: *dim
+    heads: 12
+    dim_head: 64
+    ff_mult: 4
+    depth: 6
+    dropout_rate: 0.1
+  channels: 1
+image_attn_pool:
+  _target_: text_recognizer.network.transformer.attention.Attention
+  dim: *dim
+  heads: 8
+  causal: false
+  dim_head: 64
+  ff_mult: 4
+  dropout_rate: 0.0
+  use_flash: true
+  norm_context: true
+  rotary_emb: null
+decoder:
+  _target_: text_recognizer.network.transformer.decoder.Decoder
+  dim: *dim
+  ff_mult: 4
+  heads: 12
+  dim_head: 64
+  depth: 6
+  dropout_rate: 0.1
+dim: *dim
+dim_latent: *dim
+num_tokens: 58
+pad_index: 3
+num_image_queries: 256
diff --git a/training/conf/network/vit_lines.yaml b/training/conf/network/vit_lines.yaml
index f32cb83..638dae1 100644
--- a/training/conf/network/vit_lines.yaml
+++ b/training/conf/network/vit_lines.yaml
@@ -1,37 +1,51 @@
-_target_: text_recognizer.network.vit.VisionTransformer
-image_height: 56
-image_width: 1024
-patch_height: 28
-patch_width: 32
-dim: &dim 1024
+_target_: text_recognizer.network.convformer.Convformer
+image_height: 7
+image_width: 128
+patch_height: 7
+patch_width: 1
+dim: &dim 768
 num_classes: &num_classes 58
 encoder:
   _target_: text_recognizer.network.transformer.encoder.Encoder
   dim: *dim
-  inner_dim: 2048
-  heads: 16
+  inner_dim: 3072
+  ff_mult: 4
+  heads: 12
   dim_head: 64
-  depth: 6
-  dropout_rate: 0.0
+  depth: 4
+  dropout_rate: 0.1
 decoder:
   _target_: text_recognizer.network.transformer.decoder.Decoder
   dim: *dim
-  inner_dim: 2048
-  heads: 16
+  inner_dim: 3072
+  ff_mult: 4
+  heads: 12
   dim_head: 64
   depth: 6
-  dropout_rate: 0.0
+  dropout_rate: 0.1
 token_embedding:
   _target_: "text_recognizer.network.transformer.embedding.token.\
     TokenEmbedding"
   num_tokens: *num_classes
   dim: *dim
   use_l2: true
-pos_embedding:
-  _target_: "text_recognizer.network.transformer.embedding.absolute.\
-    AbsolutePositionalEmbedding"
-  dim: *dim
-  max_length: 89
-  use_l2: true
-tie_embeddings: false
+tie_embeddings: true
 pad_index: 3
+channels: 64
+stem:
+  _target_: text_recognizer.network.convnext.convnext.ConvNext
+  dim: 8
+  dim_mults: [2, 8, 8]
+  depths: [2, 2, 2]
+  attn: null
+    # _target_: text_recognizer.network.convnext.transformer.Transformer
+    # attn:
+    #   _target_: text_recognizer.network.convnext.transformer.Attention
+    #   dim: 64
+    #   heads: 4
+    #   dim_head: 64
+    #   scale: 8
+    # ff:
+    #   _target_: text_recognizer.network.convnext.transformer.FeedForward
+    #   dim: 64
+    #   mult: 4
-- 
cgit v1.2.3-70-g09d2