From 2c9066b685d41ef0ab5ea94e938b8a30b4123656 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Wed, 15 Jun 2022 00:15:40 +0200
Subject: Update configs

---
 training/conf/callbacks/default.yaml               |  2 +-
 training/conf/callbacks/lightning/checkpoint.yaml  |  8 ++---
 .../conf/experiment/conv_transformer_lines.yaml    | 39 +++++++++++++---------
 training/conf/lr_schedulers/one_cycle.yaml         | 37 ++++++++++----------
 training/conf/model/lit_transformer.yaml           |  2 +-
 training/conf/network/conv_transformer.yaml        |  2 +-
 .../conf/network/decoder/transformer_decoder.yaml  | 30 -----------------
 training/conf/network/encoder/efficientnet.yaml    |  5 ---
 training/conf/trainer/default.yaml                 |  2 +-
 9 files changed, 49 insertions(+), 78 deletions(-)
 delete mode 100644 training/conf/network/decoder/transformer_decoder.yaml
 delete mode 100644 training/conf/network/encoder/efficientnet.yaml

diff --git a/training/conf/callbacks/default.yaml b/training/conf/callbacks/default.yaml
index 57c10a6..4d8e399 100644
--- a/training/conf/callbacks/default.yaml
+++ b/training/conf/callbacks/default.yaml
@@ -2,5 +2,5 @@ defaults:
   - lightning/checkpoint
   - lightning/learning_rate_monitor
   - wandb/watch
-  - wandb/config 
+  - wandb/config
   - wandb/checkpoints
diff --git a/training/conf/callbacks/lightning/checkpoint.yaml b/training/conf/callbacks/lightning/checkpoint.yaml
index b4101d8..9acd64f 100644
--- a/training/conf/callbacks/lightning/checkpoint.yaml
+++ b/training/conf/callbacks/lightning/checkpoint.yaml
@@ -1,9 +1,9 @@
 model_checkpoint:
   _target_: pytorch_lightning.callbacks.ModelCheckpoint
-  monitor: val/loss # name of the logged metric which determines when model is improving
-  save_top_k: 1 # save k best models (determined by above metric)
-  save_last: true # additionaly always save model from last epoch
-  mode: min # can be "max" or "min"
+  monitor: val/cer
+  save_top_k: 1
+  save_last: true
+  mode: min
   verbose: false
   dirpath: checkpoints/
   filename: "{epoch:02d}"
diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml
index 8404cd1..38b13a5 100644
--- a/training/conf/experiment/conv_transformer_lines.yaml
+++ b/training/conf/experiment/conv_transformer_lines.yaml
@@ -18,7 +18,7 @@ summary: [[1, 1, 56, 1024], [1, 89]]
 
 criterion:
   ignore_index: *ignore_index
-  # label_smoothing: 0.1
+  label_smoothing: 0.05
 
 callbacks:
   stochastic_weight_averaging:
@@ -40,30 +40,38 @@ optimizers:
 
 lr_schedulers:
   network:
-    _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
-    mode: min
-    factor: 0.5
-    patience: 10
-    threshold: 1.0e-4
-    threshold_mode: rel
-    cooldown: 0
-    min_lr: 1.0e-5
-    eps: 1.0e-8
+    _target_: torch.optim.lr_scheduler.OneCycleLR
+    max_lr: 3.0e-4
+    total_steps: null
+    epochs: *epochs
+    steps_per_epoch: 1284
+    pct_start: 0.3
+    anneal_strategy: cos
+    cycle_momentum: true
+    base_momentum: 0.85
+    max_momentum: 0.95
+    div_factor: 25.0
+    final_div_factor: 10000.0
+    three_phase: true
+    last_epoch: -1
     verbose: false
-    interval: epoch
-    monitor: val/loss
+    interval: step
+    monitor: val/cer
 
 datamodule:
-  batch_size: 16
+  batch_size: 8
+  train_fraction: 0.9
 
 network:
   input_dims: [1, 1, 56, 1024]
   num_classes: *num_classes
   pad_index: *ignore_index
+  encoder:
+    depth: 5
   decoder:
-    depth: 10
+    depth: 6
   pixel_embedding:
-    shape: [7, 128]
+    shape: [3, 64]
 
 model:
   max_output_len: *max_output_len
@@ -71,3 +79,4 @@ model:
 trainer:
   gradient_clip_val: 0.5
   max_epochs: *epochs
+  accumulate_grad_batches: 1
diff --git a/training/conf/lr_schedulers/one_cycle.yaml b/training/conf/lr_schedulers/one_cycle.yaml
index 801a01f..20eab9f 100644
--- a/training/conf/lr_schedulers/one_cycle.yaml
+++ b/training/conf/lr_schedulers/one_cycle.yaml
@@ -1,20 +1,17 @@
-one_cycle:
-  _target_: torch.optim.lr_scheduler.OneCycleLR
-  max_lr: 1.0e-3
-  total_steps: null
-  epochs: 512
-  steps_per_epoch: 4992
-  pct_start: 0.3
-  anneal_strategy: cos
-  cycle_momentum: true
-  base_momentum: 0.85
-  max_momentum: 0.95
-  div_factor: 25.0
-  final_div_factor: 10000.0
-  three_phase: true
-  last_epoch: -1
-  verbose: false
-
-  # Non-class arguments
-  interval: step
-  monitor: val/loss
+_target_: torch.optim.lr_scheduler.OneCycleLR
+max_lr: 1.0e-3
+total_steps: null
+epochs: 512
+steps_per_epoch: 4992
+pct_start: 0.3
+anneal_strategy: cos
+cycle_momentum: true
+base_momentum: 0.85
+max_momentum: 0.95
+div_factor: 25.0
+final_div_factor: 10000.0
+three_phase: true
+last_epoch: -1
+verbose: false
+interval: step
+monitor: val/loss
diff --git a/training/conf/model/lit_transformer.yaml b/training/conf/model/lit_transformer.yaml
index c1491ec..b795078 100644
--- a/training/conf/model/lit_transformer.yaml
+++ b/training/conf/model/lit_transformer.yaml
@@ -5,4 +5,4 @@ end_token: <e>
 pad_token: <p>
 mapping:
   _target_: text_recognizer.data.mappings.EmnistMapping
-  extra_symbols: ["\n"]
+  # extra_symbols: ["\n"]
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml
index 54eb028..39c5c46 100644
--- a/training/conf/network/conv_transformer.yaml
+++ b/training/conf/network/conv_transformer.yaml
@@ -10,7 +10,7 @@ encoder:
   bn_momentum: 0.99
   bn_eps: 1.0e-3
   depth: 3
-  out_channels: 128
+  out_channels: *hidden_dim
 decoder:
   _target_: text_recognizer.networks.transformer.Decoder
   depth: 6
diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml
deleted file mode 100644
index 4588ee9..0000000
--- a/training/conf/network/decoder/transformer_decoder.yaml
+++ /dev/null
@@ -1,30 +0,0 @@
-_target_: text_recognizer.networks.transformer.decoder.Decoder
-depth: 4
-block:
-  _target_: text_recognizer.networks.transformer.decoder.DecoderBlock
-  self_attn:
-    _target_: text_recognizer.networks.transformer.attention.Attention
-    dim: 64
-    num_heads: 4
-    dim_head: 64
-    dropout_rate: 0.05
-    causal: true
-    rotary_embedding:
-      _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding
-      dim: 128
-  cross_attn:
-    _target_: text_recognizer.networks.transformer.attention.Attention
-    dim: 64
-    num_heads: 4
-    dim_head: 64
-    dropout_rate: 0.05
-    causal: false
-  norm:
-    _target_: text_recognizer.networks.transformer.norm.RMSNorm
-    normalized_shape: 192
-  ff:
-    _target_: text_recognizer.networks.transformer.mlp.FeedForward
-    dim_out: null
-    expansion_factor: 4
-    glu: true
-    dropout_rate: 0.2
diff --git a/training/conf/network/encoder/efficientnet.yaml b/training/conf/network/encoder/efficientnet.yaml
deleted file mode 100644
index a7be069..0000000
--- a/training/conf/network/encoder/efficientnet.yaml
+++ /dev/null
@@ -1,5 +0,0 @@
-_target_: text_recognizer.networks.efficientnet.EfficientNet
-arch: b0
-stochastic_dropout_rate: 0.2
-bn_momentum: 0.99
-bn_eps: 1.0e-3
diff --git a/training/conf/trainer/default.yaml b/training/conf/trainer/default.yaml
index d4ffcdc..c2d0d62 100644
--- a/training/conf/trainer/default.yaml
+++ b/training/conf/trainer/default.yaml
@@ -13,5 +13,5 @@ limit_train_batches: 1.0
 limit_val_batches: 1.0
 limit_test_batches: 1.0
 resume_from_checkpoint: null
-accumulate_grad_batches: 2
+accumulate_grad_batches: 1
 overfit_batches: 0
-- 
cgit v1.2.3-70-g09d2