From becc8e1380a36f45a8dadf5a7cc6c7b922fe8dff Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Sun, 3 Oct 2021 00:33:33 +0200
Subject: Add experiment configs
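
Add experiment configs for word-piece line recognition
(cnn_htr_wp_lines) and paragraph recognition
(cnn_transformer_paragraphs), and rework cnn_htr_char_lines: switch to
the plain EMNIST character mapping, replace ReduceLROnPlateau with
OneCycleLR, use ScaleNorm and rotary/2D positional embeddings in the
decoder, and hoist shared hyperparameters (epochs, ignore_index,
num_classes, max_output_len) into YAML anchors.

Also disable logging of train predictions in the W&B HTR callback and
bump the default max_output_len in lit_transformer to 682 to cover
paragraphs. The experiments are selected via Hydra's experiment config
group, e.g. +experiment=cnn_htr_wp_lines.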

---
 training/conf/callbacks/wandb_htr_predictions.yaml |   2 +-
 training/conf/experiment/cnn_htr_char_lines.yaml   |  90 +++++++-----
 training/conf/experiment/cnn_htr_wp_lines.yaml     | 157 +++++++++++++++++++++
 .../experiment/cnn_transformer_paragraphs.yaml     | 148 +++++++++++++++++++
 training/conf/model/lit_transformer.yaml           |   2 +-
 5 files changed, 361 insertions(+), 38 deletions(-)
 create mode 100644 training/conf/experiment/cnn_htr_wp_lines.yaml
 create mode 100644 training/conf/experiment/cnn_transformer_paragraphs.yaml

diff --git a/training/conf/callbacks/wandb_htr_predictions.yaml b/training/conf/callbacks/wandb_htr_predictions.yaml
index 589e7e0..468b6e0 100644
--- a/training/conf/callbacks/wandb_htr_predictions.yaml
+++ b/training/conf/callbacks/wandb_htr_predictions.yaml
@@ -1,4 +1,4 @@
 log_text_predictions:
   _target_: callbacks.wandb_callbacks.LogTextPredictions
   num_samples: 8
-  log_train: true
+  log_train: false
diff --git a/training/conf/experiment/cnn_htr_char_lines.yaml b/training/conf/experiment/cnn_htr_char_lines.yaml
index 0f28ff9..0d62a73 100644
--- a/training/conf/experiment/cnn_htr_char_lines.yaml
+++ b/training/conf/experiment/cnn_htr_char_lines.yaml
@@ -10,28 +10,26 @@ defaults:
   - override /optimizers: null
 
 
+epochs: &epochs 200
+ignore_index: &ignore_index 3
+num_classes: &num_classes 58
+max_output_len: &max_output_len 89
+
 criterion:
   _target_: text_recognizer.criterions.label_smoothing.LabelSmoothingLoss
   smoothing: 0.1 
-  ignore_index: 1000
+  ignore_index: *ignore_index
+  # _target_: torch.nn.CrossEntropyLoss
+  # ignore_index: *ignore_index
     
 mapping:
-  _target_: text_recognizer.data.word_piece_mapping.WordPieceMapping
-  num_features: 1000
-  tokens: iamdb_1kwp_tokens_1000.txt
-  lexicon: iamdb_1kwp_lex_1000.txt
-  data_dir: null
-  use_words: false
-  prepend_wordsep: false
-  special_tokens: [ <s>, <e>, <p> ]
-  # _target_: text_recognizer.data.emnist_mapping.EmnistMapping
-  # extra_symbols: [ "\n" ]
+  _target_: text_recognizer.data.emnist_mapping.EmnistMapping
 
 callbacks:
   stochastic_weight_averaging:
     _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
-    swa_epoch_start: 0.8
-    swa_lrs: 0.05
+    swa_epoch_start: 0.75
+    swa_lrs: 5.0e-5
     annealing_epochs: 10
     annealing_strategy: cos
     device: null
@@ -39,7 +37,7 @@ callbacks:
 optimizers:
   madgrad:
     _target_: madgrad.MADGRAD
-    lr: 1.0e-4
+    lr: 3.0e-4
     momentum: 0.9
     weight_decay: 0
     eps: 1.0e-6
@@ -48,34 +46,42 @@ optimizers:
 
 lr_schedulers:
   network:
-    _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
-    mode: min
-    factor: 0.1
-    patience: 10
-    threshold: 1.0e-4
-    threshold_mode: rel
-    cooldown: 0
-    min_lr: 1.0e-7
-    eps: 1.0e-8
-    interval: epoch
+    _target_: torch.optim.lr_scheduler.OneCycleLR
+    max_lr: 3.0e-4
+    total_steps: null
+    epochs: *epochs
+    steps_per_epoch: 90
+    pct_start: 0.1
+    anneal_strategy: cos
+    cycle_momentum: true
+    base_momentum: 0.85
+    max_momentum: 0.95
+    div_factor: 25
+    final_div_factor: 1.0e4
+    three_phase: false
+    last_epoch: -1
+    verbose: false
+    # Non-class arguments
+    interval: step
     monitor: val/loss
 
 datamodule:
   _target_: text_recognizer.data.iam_lines.IAMLines
-  batch_size: 16
+  batch_size: 32
   num_workers: 12
   train_fraction: 0.8
   augment: true
-  pin_memory: false
+  pin_memory: true
+  word_pieces: false
 
 network:
   _target_: text_recognizer.networks.conv_transformer.ConvTransformer
   input_dims: [1, 56, 1024]
-  hidden_dim: 128
+  hidden_dim: &hidden_dim 128
   encoder_dim: 1280
   dropout_rate: 0.2
-  num_classes: 1006
-  pad_index: 1000
+  num_classes: *num_classes
+  pad_index: *ignore_index
   encoder:
     _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet
     arch: b0
@@ -85,14 +91,14 @@ network:
     bn_eps: 1.0e-3
   decoder:
     _target_: text_recognizer.networks.transformer.Decoder
-    dim: 128
+    dim: *hidden_dim
     depth: 3 
     num_heads: 4
     attn_fn: text_recognizer.networks.transformer.attention.Attention
     attn_kwargs:
       dim_head: 32
       dropout_rate: 0.2
-    norm_fn: torch.nn.LayerNorm
+    norm_fn: text_recognizer.networks.transformer.norm.ScaleNorm
     ff_fn: text_recognizer.networks.transformer.mlp.FeedForward
     ff_kwargs:
       dim_out: null
@@ -101,11 +107,23 @@ network:
       dropout_rate: 0.2
     cross_attend: true
     pre_norm: true
-    rotary_emb: null
+    rotary_emb:
+      _target_: text_recognizer.networks.transformer.positional_encodings.rotary_embedding.RotaryEmbedding
+      dim: 32
+  pixel_pos_embedding:
+    _target_: text_recognizer.networks.transformer.positional_encodings.PositionalEncoding2D
+    hidden_dim: *hidden_dim 
+    max_h: 1
+    max_w: 32
+  token_pos_embedding:
+    _target_: text_recognizer.networks.transformer.positional_encodings.PositionalEncoding
+    hidden_dim: *hidden_dim 
+    dropout_rate: 0.2
+    max_len: *max_output_len
 
 model:
   _target_: text_recognizer.models.transformer.TransformerLitModel
-  max_output_len: 89
+  max_output_len: *max_output_len
   start_token: <s>
   end_token: <e>
   pad_token: <p>
@@ -115,11 +133,11 @@ trainer:
   stochastic_weight_avg: true
   auto_scale_batch_size: binsearch
   auto_lr_find: false
-  gradient_clip_val: 0
+  gradient_clip_val: 0.5
   fast_dev_run: false
   gpus: 1
   precision: 16
-  max_epochs: 1024
+  max_epochs: *epochs
   terminate_on_nan: true
   weights_summary: null
   limit_train_batches: 1.0 
@@ -127,6 +145,6 @@ trainer:
   limit_test_batches: 1.0
   resume_from_checkpoint: null
   accumulate_grad_batches: 4
-  overfit_batches: 0.0
+  overfit_batches: 0
 
 summary: [[1, 1, 56, 1024], [1, 89]]
diff --git a/training/conf/experiment/cnn_htr_wp_lines.yaml b/training/conf/experiment/cnn_htr_wp_lines.yaml
new file mode 100644
index 0000000..79075cd
--- /dev/null
+++ b/training/conf/experiment/cnn_htr_wp_lines.yaml
@@ -0,0 +1,157 @@
+# @package _global_
+
+defaults:
+  - override /mapping: null
+  - override /criterion: null
+  - override /datamodule: null
+  - override /network: null
+  - override /model: null
+  - override /lr_schedulers: null
+  - override /optimizers: null
+
+epochs: &epochs 256
+ignore_index: &ignore_index 1000
+num_classes: &num_classes 1006
+max_output_len: &max_output_len 72
+
+criterion:
+  _target_: torch.nn.CrossEntropyLoss
+  ignore_index: *ignore_index
+  # _target_: text_recognizer.criterions.label_smoothing.LabelSmoothingLoss
+  # smoothing: 0.1 
+  # ignore_index: *ignore_index
+    
+mapping:
+  _target_: text_recognizer.data.word_piece_mapping.WordPieceMapping
+  num_features: 1000
+  tokens: iamdb_1kwp_tokens_1000.txt
+  lexicon: iamdb_1kwp_lex_1000.txt
+  data_dir: null
+  use_words: false
+  prepend_wordsep: false
+  special_tokens: [ <s>, <e>, <p> ]
+  # extra_symbols: [ "\n" ]
+
+callbacks:
+  stochastic_weight_averaging:
+    _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+    swa_epoch_start: 0.75
+    swa_lrs: 1.0e-5
+    annealing_epochs: 10
+    annealing_strategy: cos
+    device: null
+
+optimizers:
+  madgrad:
+    _target_: madgrad.MADGRAD
+    lr: 3.0e-4
+    momentum: 0.9
+    weight_decay: 0
+    eps: 1.0e-6
+
+    parameters: network
+
+lr_schedulers:
+  network:
+    _target_: torch.optim.lr_scheduler.OneCycleLR
+    max_lr: 3.0e-4
+    total_steps: null
+    epochs: *epochs
+    steps_per_epoch: 90
+    pct_start: 0.1
+    anneal_strategy: cos
+    cycle_momentum: true
+    base_momentum: 0.85
+    max_momentum: 0.95
+    div_factor: 25
+    final_div_factor: 1.0e4
+    three_phase: false
+    last_epoch: -1
+    verbose: false
+    # Non-class arguments
+    interval: step
+    monitor: val/loss
+
+datamodule:
+  _target_: text_recognizer.data.iam_lines.IAMLines
+  batch_size: 32
+  num_workers: 12
+  train_fraction: 0.8
+  augment: true
+  pin_memory: true
+  word_pieces: true
+
+network:
+  _target_: text_recognizer.networks.conv_transformer.ConvTransformer
+  input_dims: [1, 56, 1024]
+  hidden_dim: &hidden_dim 128
+  encoder_dim: 1280
+  dropout_rate: 0.2
+  num_classes: *num_classes
+  pad_index: *ignore_index
+  encoder:
+    _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet
+    arch: b0
+    out_channels: 1280
+    stochastic_dropout_rate: 0.2
+    bn_momentum: 0.99
+    bn_eps: 1.0e-3
+  decoder:
+    _target_: text_recognizer.networks.transformer.Decoder
+    dim: *hidden_dim
+    depth: 3 
+    num_heads: 4
+    attn_fn: text_recognizer.networks.transformer.attention.Attention
+    attn_kwargs:
+      dim_head: 32
+      dropout_rate: 0.2
+    norm_fn: text_recognizer.networks.transformer.norm.ScaleNorm
+    ff_fn: text_recognizer.networks.transformer.mlp.FeedForward
+    ff_kwargs:
+      dim_out: null
+      expansion_factor: 4
+      glu: true
+      dropout_rate: 0.2
+    cross_attend: true
+    pre_norm: true
+    rotary_emb:
+      _target_: text_recognizer.networks.transformer.positional_encodings.rotary_embedding.RotaryEmbedding
+      dim: 32
+  pixel_pos_embedding:
+    _target_: text_recognizer.networks.transformer.positional_encodings.PositionalEncoding2D
+    hidden_dim: *hidden_dim 
+    max_h: 1
+    max_w: 32
+  token_pos_embedding:
+    _target_: text_recognizer.networks.transformer.positional_encodings.PositionalEncoding
+    hidden_dim: *hidden_dim 
+    dropout_rate: 0.2
+    max_len: *max_output_len
+
+model:
+  _target_: text_recognizer.models.transformer.TransformerLitModel
+  max_output_len: *max_output_len
+  start_token: <s>
+  end_token: <e>
+  pad_token: <p>
+
+trainer:
+  _target_: pytorch_lightning.Trainer
+  stochastic_weight_avg: true
+  auto_scale_batch_size: binsearch
+  auto_lr_find: false
+  gradient_clip_val: 0.5
+  fast_dev_run: false
+  gpus: 1
+  precision: 16
+  max_epochs: *epochs
+  terminate_on_nan: true
+  weights_summary: null
+  limit_train_batches: 1.0 
+  limit_val_batches: 1.0
+  limit_test_batches: 1.0
+  resume_from_checkpoint: null
+  accumulate_grad_batches: 4
+  overfit_batches: 0
+
+summary: [[1, 1, 56, 1024], [1, 72]]
diff --git a/training/conf/experiment/cnn_transformer_paragraphs.yaml b/training/conf/experiment/cnn_transformer_paragraphs.yaml
new file mode 100644
index 0000000..b415c29
--- /dev/null
+++ b/training/conf/experiment/cnn_transformer_paragraphs.yaml
@@ -0,0 +1,148 @@
+# @package _global_
+
+defaults:
+  - override /mapping: null
+  - override /criterion: null
+  - override /datamodule: null
+  - override /network: null
+  - override /model: null
+  - override /lr_schedulers: null
+  - override /optimizers: null
+
+
+epochs: &epochs 512
+ignore_index: &ignore_index 3
+num_classes: &num_classes 58
+max_output_len: &max_output_len 682
+summary: [[1, 1, 576, 640], [1, 682]]
+
+criterion:
+  _target_: torch.nn.CrossEntropyLoss
+  ignore_index: *ignore_index
+    
+mapping:
+  _target_: text_recognizer.data.emnist_mapping.EmnistMapping
+  extra_symbols: [ "\n" ]
+
+callbacks:
+  stochastic_weight_averaging:
+    _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+    swa_epoch_start: 0.75
+    swa_lrs: 1.0e-5
+    annealing_epochs: 10
+    annealing_strategy: cos
+    device: null
+
+optimizers:
+  madgrad:
+    _target_: madgrad.MADGRAD
+    lr: 3.0e-4
+    momentum: 0.9
+    weight_decay: 0
+    eps: 1.0e-6
+
+    parameters: network
+
+lr_schedulers:
+  network:
+    _target_: torch.optim.lr_scheduler.OneCycleLR
+    max_lr: 3.0e-4
+    total_steps: null
+    epochs: *epochs
+    steps_per_epoch: 52
+    pct_start: 0.1
+    anneal_strategy: cos
+    cycle_momentum: true
+    base_momentum: 0.85
+    max_momentum: 0.95
+    div_factor: 25
+    final_div_factor: 1.0e4
+    three_phase: false
+    last_epoch: -1
+    verbose: false
+    # Non-class arguments
+    interval: step
+    monitor: val/loss
+
+datamodule:
+  _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs
+  batch_size: 4
+  num_workers: 12
+  train_fraction: 0.8
+  augment: true
+  pin_memory: false
+  word_pieces: false
+  resize: null
+
+network:
+  _target_: text_recognizer.networks.conv_transformer.ConvTransformer
+  input_dims: [1, 576, 640]
+  hidden_dim: &hidden_dim 128
+  encoder_dim: 1280
+  dropout_rate: 0.2
+  num_classes: *num_classes
+  pad_index: *ignore_index
+  encoder:
+    _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet
+    arch: b0
+    out_channels: 1280
+    stochastic_dropout_rate: 0.2
+    bn_momentum: 0.99
+    bn_eps: 1.0e-3
+  decoder:
+    _target_: text_recognizer.networks.transformer.Decoder
+    dim: *hidden_dim
+    depth: 3 
+    num_heads: 4
+    attn_fn: text_recognizer.networks.transformer.attention.Attention
+    attn_kwargs:
+      dim_head: 32
+      dropout_rate: 0.2
+    norm_fn: text_recognizer.networks.transformer.norm.ScaleNorm
+    ff_fn: text_recognizer.networks.transformer.mlp.FeedForward
+    ff_kwargs:
+      dim_out: null
+      expansion_factor: 4
+      glu: true
+      dropout_rate: 0.2
+    cross_attend: true
+    pre_norm: true
+    rotary_emb:
+      _target_: text_recognizer.networks.transformer.positional_encodings.rotary_embedding.RotaryEmbedding
+      dim: 32
+  pixel_pos_embedding:
+    _target_: text_recognizer.networks.transformer.positional_encodings.PositionalEncoding2D
+    hidden_dim: *hidden_dim 
+    max_h: 18
+    max_w: 20
+  token_pos_embedding:
+    _target_: text_recognizer.networks.transformer.positional_encodings.PositionalEncoding
+    hidden_dim: *hidden_dim 
+    dropout_rate: 0.2
+    max_len: *max_output_len
+
+model:
+  _target_: text_recognizer.models.transformer.TransformerLitModel
+  max_output_len: *max_output_len
+  start_token: <s>
+  end_token: <e>
+  pad_token: <p>
+
+trainer:
+  _target_: pytorch_lightning.Trainer
+  stochastic_weight_avg: true
+  auto_scale_batch_size: binsearch
+  auto_lr_find: false
+  gradient_clip_val: 0.5
+  fast_dev_run: false
+  gpus: 1
+  precision: 16
+  max_epochs: *epochs
+  terminate_on_nan: true
+  weights_summary: null
+  limit_train_batches: 1.0 
+  limit_val_batches: 1.0
+  limit_test_batches: 1.0
+  resume_from_checkpoint: null
+  accumulate_grad_batches: 32
+  overfit_batches: 0
diff --git a/training/conf/model/lit_transformer.yaml b/training/conf/model/lit_transformer.yaml
index 0ec3b8a..5172533 100644
--- a/training/conf/model/lit_transformer.yaml
+++ b/training/conf/model/lit_transformer.yaml
@@ -1,5 +1,5 @@
 _target_: text_recognizer.models.transformer.TransformerLitModel
-max_output_len: 451
+max_output_len: 682
 start_token: <s>
 end_token: <e>
 pad_token: <p>
-- 
cgit v1.2.3-70-g09d2