From 56dc112cfb649217cd624b4ff305e2db83a383b7 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Mon, 11 Sep 2023 22:15:26 +0200
Subject: Update configs

---
 training/conf/callbacks/wandb/watch.yaml         |  2 +-
 training/conf/decoder/greedy.yaml                |  2 +-
 training/conf/experiment/mammut_lines.yaml       | 55 ++++++++++++++++++++++++
 training/conf/experiment/vit_lines.yaml          | 43 ++++++------------
 training/conf/logger/csv.yaml                    |  4 --
 training/conf/lr_scheduler/cosine_annealing.yaml |  2 +-
 training/conf/model/lit_mammut.yaml              |  4 ++
 training/conf/network/convnext.yaml              | 16 +++++++
 training/conf/network/mammut_lines.yaml          | 41 ++++++++++++++++++
 training/conf/network/vit_lines.yaml             | 54 ++++++++++++++---------
 training/conf/optimizer/adamw.yaml               |  5 +++
 training/conf/optimizer/adan.yaml                |  4 ++
 training/conf/optimizer/lion.yaml                |  5 +++
 13 files changed, 180 insertions(+), 57 deletions(-)
 create mode 100644 training/conf/experiment/mammut_lines.yaml
 delete mode 100644 training/conf/logger/csv.yaml
 create mode 100644 training/conf/model/lit_mammut.yaml
 create mode 100644 training/conf/network/convnext.yaml
 create mode 100644 training/conf/network/mammut_lines.yaml
 create mode 100644 training/conf/optimizer/adamw.yaml
 create mode 100644 training/conf/optimizer/adan.yaml
 create mode 100644 training/conf/optimizer/lion.yaml

diff --git a/training/conf/callbacks/wandb/watch.yaml b/training/conf/callbacks/wandb/watch.yaml
index 1f60978..945b624 100644
--- a/training/conf/callbacks/wandb/watch.yaml
+++ b/training/conf/callbacks/wandb/watch.yaml
@@ -2,4 +2,4 @@ watch_model:
   _target_: callbacks.wandb.WatchModel
   log_params: gradients
   log_freq: 100
-  log_graph: true
+  log_graph: false
diff --git a/training/conf/decoder/greedy.yaml b/training/conf/decoder/greedy.yaml
index a88b5a6..44ef53e 100644
--- a/training/conf/decoder/greedy.yaml
+++ b/training/conf/decoder/greedy.yaml
@@ -1,2 +1,2 @@
-_target_: text_recognizer.model.greedy_decoder.GreedyDecoder
+_target_: text_recognizer.decoder.greedy_decoder.GreedyDecoder
 max_output_len: 682
diff --git a/training/conf/experiment/mammut_lines.yaml b/training/conf/experiment/mammut_lines.yaml
new file mode 100644
index 0000000..e74e219
--- /dev/null
+++ b/training/conf/experiment/mammut_lines.yaml
@@ -0,0 +1,55 @@
+# @package _global_
+
+defaults:
+  - override /criterion: cross_entropy
+  - override /callbacks: htr
+  - override /datamodule: iam_lines
+  - override /network: mammut_lines
+  - override /model: lit_mammut
+  - override /lr_scheduler: cosine_annealing
+  - override /optimizer: adan
+
+tags: [lines, vit]
+epochs: &epochs 320
+ignore_index: &ignore_index 3
+# summary: [[1, 1, 56, 1024], [1, 89]]
+
+logger:
+  wandb:
+    tags: ${tags}
+
+criterion:
+  ignore_index: *ignore_index
+  # label_smoothing: 0.05
+
+
+decoder:
+  max_output_len: 89
+
+# callbacks:
+#   stochastic_weight_averaging:
+#     _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+#     swa_epoch_start: 0.75
+#     swa_lrs: 1.0e-5
+#     annealing_epochs: 10
+#     annealing_strategy: cos
+#     device: null
+
+lr_scheduler:
+  T_max: *epochs
+
+datamodule:
+  batch_size: 8
+  train_fraction: 0.95
+
+model:
+  max_output_len: 89
+
+trainer:
+  fast_dev_run: false
+  gradient_clip_val: 1.0
+  max_epochs: *epochs
+  accumulate_grad_batches: 1
+  limit_train_batches: 1.0
+  limit_val_batches: 1.0
+  limit_test_batches: 1.0
diff --git a/training/conf/experiment/vit_lines.yaml b/training/conf/experiment/vit_lines.yaml
index f3049ea..08ed481 100644
--- a/training/conf/experiment/vit_lines.yaml
+++ b/training/conf/experiment/vit_lines.yaml
@@ -6,11 +6,11 @@ defaults:
   - override /datamodule: iam_lines
   - override /network: vit_lines
   - override /model: lit_transformer
-  - override /lr_scheduler: null
-  - override /optimizer: null
+  - override /lr_scheduler: cosine_annealing
+  - override /optimizer: adan
 
 tags: [lines, vit]
-epochs: &epochs 128
+epochs: &epochs 320
 ignore_index: &ignore_index 3
 # summary: [[1, 1, 56, 1024], [1, 89]]
 
@@ -26,37 +26,20 @@ criterion:
 decoder:
   max_output_len: 89
 
-callbacks:
-  stochastic_weight_averaging:
-    _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
-    swa_epoch_start: 0.75
-    swa_lrs: 1.0e-5
-    annealing_epochs: 10
-    annealing_strategy: cos
-    device: null
-
-optimizer:
-  _target_: adan_pytorch.Adan
-  lr: 3.0e-4
-  betas: [0.02, 0.08, 0.01]
-  weight_decay: 0.02
+# callbacks:
+#   stochastic_weight_averaging:
+#     _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+#     swa_epoch_start: 0.75
+#     swa_lrs: 1.0e-5
+#     annealing_epochs: 10
+#     annealing_strategy: cos
+#     device: null
 
 lr_scheduler:
-  _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
-  mode: min
-  factor: 0.8
-  patience: 10
-  threshold: 1.0e-4
-  threshold_mode: rel
-  cooldown: 0
-  min_lr: 1.0e-5
-  eps: 1.0e-8
-  verbose: false
-  interval: epoch
-  monitor: val/cer
+  T_max: *epochs
 
 datamodule:
-  batch_size: 16
+  batch_size: 8
   train_fraction: 0.95
 
 model:
diff --git a/training/conf/logger/csv.yaml b/training/conf/logger/csv.yaml
deleted file mode 100644
index 9fa6cad..0000000
--- a/training/conf/logger/csv.yaml
+++ /dev/null
@@ -1,4 +0,0 @@
-csv:
-  _target_: pytorch_lightning.loggers.CSVLogger
-  name: null
-  save_dir: "."
diff --git a/training/conf/lr_scheduler/cosine_annealing.yaml b/training/conf/lr_scheduler/cosine_annealing.yaml
index e8364f0..36e87d4 100644
--- a/training/conf/lr_scheduler/cosine_annealing.yaml
+++ b/training/conf/lr_scheduler/cosine_annealing.yaml
@@ -2,6 +2,6 @@ _target_: torch.optim.lr_scheduler.CosineAnnealingLR
 T_max: 256
 eta_min: 0.0
 last_epoch: -1
-
+verbose: false
 interval: epoch
 monitor: val/loss
diff --git a/training/conf/model/lit_mammut.yaml b/training/conf/model/lit_mammut.yaml
new file mode 100644
index 0000000..2495692
--- /dev/null
+++ b/training/conf/model/lit_mammut.yaml
@@ -0,0 +1,4 @@
+_target_: text_recognizer.model.mammut.LitMaMMUT
+max_output_len: 682
+caption_loss_weight: 1.0
+contrastive_loss_weight: 1.0
diff --git a/training/conf/network/convnext.yaml b/training/conf/network/convnext.yaml
new file mode 100644
index 0000000..40343a7
--- /dev/null
+++ b/training/conf/network/convnext.yaml
@@ -0,0 +1,16 @@
+_target_: text_recognizer.network.convnext.convnext.ConvNext
+dim: 8
+dim_mults: [2, 8]
+depths: [2, 2]
+attn:
+  _target_: text_recognizer.network.convnext.transformer.Transformer
+  attn:
+    _target_: text_recognizer.network.convnext.transformer.Attention
+    dim: 64
+    heads: 4
+    dim_head: 64
+    scale: 8
+  ff:
+    _target_: text_recognizer.network.convnext.transformer.FeedForward
+    dim: 64
+    mult: 4
diff --git a/training/conf/network/mammut_lines.yaml b/training/conf/network/mammut_lines.yaml
new file mode 100644
index 0000000..f1c73d0
--- /dev/null
+++ b/training/conf/network/mammut_lines.yaml
@@ -0,0 +1,41 @@
+_target_: text_recognizer.network.mammut.MaMMUT
+encoder:
+  _target_: text_recognizer.network.vit.Vit
+  image_height: 56
+  image_width: 1024
+  patch_height: 56
+  patch_width: 8
+  dim: &dim 512
+  encoder:
+    _target_: text_recognizer.network.transformer.encoder.Encoder
+    dim: *dim
+    heads: 12
+    dim_head: 64
+    ff_mult: 4
+    depth: 6
+    dropout_rate: 0.1
+  channels: 1
+image_attn_pool:
+  _target_: text_recognizer.network.transformer.attention.Attention
+  dim: *dim
+  heads: 8
+  causal: false
+  dim_head: 64
+  ff_mult: 4
+  dropout_rate: 0.0
+  use_flash: true
+  norm_context: true
+  rotary_emb: null
+decoder:
+  _target_: text_recognizer.network.transformer.decoder.Decoder
+  dim: *dim
+  ff_mult: 4
+  heads: 12
+  dim_head: 64
+  depth: 6
+  dropout_rate: 0.1
+dim: *dim
+dim_latent: *dim
+num_tokens: 58
+pad_index: 3
+num_image_queries: 256
diff --git a/training/conf/network/vit_lines.yaml b/training/conf/network/vit_lines.yaml
index f32cb83..638dae1 100644
--- a/training/conf/network/vit_lines.yaml
+++ b/training/conf/network/vit_lines.yaml
@@ -1,37 +1,51 @@
-_target_: text_recognizer.network.vit.VisionTransformer
-image_height: 56
-image_width: 1024
-patch_height: 28
-patch_width: 32
-dim: &dim 1024
+_target_: text_recognizer.network.convformer.Convformer
+image_height: 7
+image_width: 128
+patch_height: 7
+patch_width: 1
+dim: &dim 768
 num_classes: &num_classes 58
 encoder:
   _target_: text_recognizer.network.transformer.encoder.Encoder
   dim: *dim
-  inner_dim: 2048
-  heads: 16
+  inner_dim: 3072
+  ff_mult: 4
+  heads: 12
   dim_head: 64
-  depth: 6
-  dropout_rate: 0.0
+  depth: 4
+  dropout_rate: 0.1
 decoder:
   _target_: text_recognizer.network.transformer.decoder.Decoder
   dim: *dim
-  inner_dim: 2048
-  heads: 16
+  inner_dim: 3072
+  ff_mult: 4
+  heads: 12
   dim_head: 64
   depth: 6
-  dropout_rate: 0.0
+  dropout_rate: 0.1
 token_embedding:
   _target_: "text_recognizer.network.transformer.embedding.token.\
     TokenEmbedding"
   num_tokens: *num_classes
   dim: *dim
   use_l2: true
-pos_embedding:
-  _target_: "text_recognizer.network.transformer.embedding.absolute.\
-    AbsolutePositionalEmbedding"
-  dim: *dim
-  max_length: 89
-  use_l2: true
-tie_embeddings: false
+tie_embeddings: true
 pad_index: 3
+channels: 64
+stem:
+  _target_: text_recognizer.network.convnext.convnext.ConvNext
+  dim: 8
+  dim_mults: [2, 8, 8]
+  depths: [2, 2, 2]
+  attn: null
+    # _target_: text_recognizer.network.convnext.transformer.Transformer
+    # attn:
+    #   _target_: text_recognizer.network.convnext.transformer.Attention
+    #   dim: 64
+    #   heads: 4
+    #   dim_head: 64
+    #   scale: 8
+    # ff:
+    #   _target_: text_recognizer.network.convnext.transformer.FeedForward
+    #   dim: 64
+    #   mult: 4
diff --git a/training/conf/optimizer/adamw.yaml b/training/conf/optimizer/adamw.yaml
new file mode 100644
index 0000000..568c67b
--- /dev/null
+++ b/training/conf/optimizer/adamw.yaml
@@ -0,0 +1,5 @@
+_target_: torch.optim.AdamW
+lr: 1.5e-4
+betas: [0.95, 0.98]
+weight_decay: 1.0e-2
+eps: 1.0e-6
diff --git a/training/conf/optimizer/adan.yaml b/training/conf/optimizer/adan.yaml
new file mode 100644
index 0000000..6950a86
--- /dev/null
+++ b/training/conf/optimizer/adan.yaml
@@ -0,0 +1,4 @@
+_target_: adan_pytorch.Adan
+lr: 2.0e-4
+betas: [0.02, 0.08, 0.01]
+weight_decay: 0.02
diff --git a/training/conf/optimizer/lion.yaml b/training/conf/optimizer/lion.yaml
new file mode 100644
index 0000000..cbf386a
--- /dev/null
+++ b/training/conf/optimizer/lion.yaml
@@ -0,0 +1,5 @@
+_target_: lion_pytorch.Lion
+lr: 5e-5
+betas: [0.95, 0.99]
+weight_decay: 0.1
+use_triton: true
-- 
cgit v1.2.3-70-g09d2