From 0879c06112e11c2091c223575573e40f086e38d5 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Wed, 3 Nov 2021 22:18:13 +0100
Subject: Update configs

Add an experiment config for line-level text recognition
(conv_transformer_lines.yaml). Retune conv_transformer_paragraphs.yaml:
lower max_lr (1.5e-4 -> 1.0e-4) and pct_start (0.03 -> 0.01), remove the
encoder_dim and dropout_rate keys from the network config, reduce decoder
depth from 4 to 3, and widen the local self-attention window (11 -> 31,
look_back 2 -> 1, autopad enabled).

---
 .../conf/experiment/conv_transformer_lines.yaml    | 151 +++++++++++++++++++++
 .../experiment/conv_transformer_paragraphs.yaml    |  13 +-
 2 files changed, 157 insertions(+), 7 deletions(-)
 create mode 100644 training/conf/experiment/conv_transformer_lines.yaml

diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml
new file mode 100644
index 0000000..d2a666f
--- /dev/null
+++ b/training/conf/experiment/conv_transformer_lines.yaml
@@ -0,0 +1,151 @@
+# @package _global_
+
+defaults:
+  - override /mapping: null
+  - override /criterion: null
+  - override /callbacks: htr
+  - override /datamodule: iam_lines
+  - override /network: null
+  - override /model: null
+  - override /lr_schedulers: null
+  - override /optimizers: null
+
+epochs: &epochs 512
+ignore_index: &ignore_index 3
+num_classes: &num_classes 58
+max_output_len: &max_output_len 89
+summary: [[1, 1, 56, 1024], [1, 89]]
+
+criterion:
+  _target_: text_recognizer.criterion.label_smoothing.LabelSmoothingLoss
+  smoothing: 0.1
+  ignore_index: *ignore_index
+
+mapping: &mapping
+  mapping:
+    _target_: text_recognizer.data.mappings.emnist.EmnistMapping
+    # extra_symbols: [ "\n" ]
+
+callbacks:
+  stochastic_weight_averaging:
+    _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+    swa_epoch_start: 0.75
+    swa_lrs: 1.0e-5
+    annealing_epochs: 10
+    annealing_strategy: cos
+    device: null
+
+optimizers:
+  madgrad:
+    _target_: madgrad.MADGRAD
+    lr: 1.0e-4
+    momentum: 0.9
+    weight_decay: 5.0e-6
+    eps: 1.0e-6
+    parameters: network
+
+lr_schedulers:
+  network:
+    _target_: torch.optim.lr_scheduler.OneCycleLR
+    max_lr: 1.0e-4
+    total_steps: null
+    epochs: *epochs
+    steps_per_epoch: 722
+    pct_start: 0.01
+    anneal_strategy: cos
+    cycle_momentum: true
+    base_momentum: 0.85
+    max_momentum: 0.95
+    div_factor: 25
+    final_div_factor: 1.0e2
+    three_phase: false
+    last_epoch: -1
+    verbose: false
+    interval: step
+    monitor: val/loss
+
+datamodule:
+  batch_size: 32
+  num_workers: 12
+  train_fraction: 0.9
+  pin_memory: true
+  << : *mapping
+
+rotary_embedding: &rotary_embedding
+  rotary_embedding:
+    _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding
+    dim: 64
+
+attn: &attn
+  dim: 192
+  num_heads: 4
+  dim_head: 64
+  dropout_rate: 0.05
+
+network:
+  _target_: text_recognizer.networks.conv_transformer.ConvTransformer
+  input_dims: [1, 56, 1024]
+  hidden_dim: &hidden_dim 192
+  num_classes: *num_classes
+  pad_index: *ignore_index
+  encoder:
+    _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet
+    arch: b0
+    out_channels: 1280
+    stochastic_dropout_rate: 0.2
+    bn_momentum: 0.99
+    bn_eps: 1.0e-3
+  decoder:
+    _target_: text_recognizer.networks.transformer.layers.Decoder
+    depth: 4
+    self_attn:
+      _target_: text_recognizer.networks.transformer.attention.Attention
+      << : *attn
+      causal: true
+      << : *rotary_embedding
+    cross_attn:
+      _target_: text_recognizer.networks.transformer.attention.Attention
+      << : *attn
+      causal: false
+    norm:
+      _target_: text_recognizer.networks.transformer.norm.ScaleNorm
+      normalized_shape: *hidden_dim
+    ff:
+      _target_: text_recognizer.networks.transformer.mlp.FeedForward
+      dim: *hidden_dim
+      dim_out: null
+      expansion_factor: 4
+      glu: true
+      dropout_rate: 0.05
+    pre_norm: true
+  pixel_pos_embedding:
+    _target_: text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbedding
+    dim: *hidden_dim
+    shape: [1, 32]
+
+model:
+  _target_: text_recognizer.models.transformer.TransformerLitModel
+  << : *mapping
+  max_output_len: *max_output_len
+  start_token: <s>
+  end_token: <e>
+  pad_token: <p>
+
+trainer:
+  _target_: pytorch_lightning.Trainer
+  stochastic_weight_avg: true
+  auto_scale_batch_size: binsearch
+  auto_lr_find: false
+  gradient_clip_val: 0.5
+  fast_dev_run: false
+  gpus: 1
+  precision: 16
+  max_epochs: *epochs
+  terminate_on_nan: true
+  weights_summary: null
+  limit_train_batches: 1.0
+  limit_val_batches: 1.0
+  limit_test_batches: 1.0
+  resume_from_checkpoint: null
+  accumulate_grad_batches: 1
+  overfit_batches: 0
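
A minimal usage sketch for the new config (the entry-point module, root
config name, and config_path below are assumptions; the instantiated nodes
all come from the YAML above). Because of the "# @package _global_" header,
the file composes as a Hydra experiment override selected on the command
line, e.g. python run.py +experiment=conv_transformer_lines:

    import hydra
    from hydra.utils import instantiate
    from omegaconf import DictConfig

    # Entry-point and config names are assumptions, not from this repo.
    @hydra.main(config_path="conf", config_name="config")
    def run(cfg: DictConfig) -> None:
        # Hydra 1.1+ instantiates nested _target_ nodes recursively.
        network = instantiate(cfg.network)        # ConvTransformer
        criterion = instantiate(cfg.criterion)    # LabelSmoothingLoss
        datamodule = instantiate(cfg.datamodule)  # iam_lines datamodule

    if __name__ == "__main__":
        run()
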
diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml
index 5fb7377..e958367 100644
--- a/training/conf/experiment/conv_transformer_paragraphs.yaml
+++ b/training/conf/experiment/conv_transformer_paragraphs.yaml
@@ -47,11 +47,11 @@ optimizers:
 lr_schedulers:
   network:
     _target_: torch.optim.lr_scheduler.OneCycleLR
-    max_lr: 1.5e-4
+    max_lr: 1.0e-4
     total_steps: null
     epochs: *epochs
     steps_per_epoch: 722
-    pct_start: 0.03
+    pct_start: 0.01
     anneal_strategy: cos
     cycle_momentum: true
     base_momentum: 0.85
@@ -87,8 +87,6 @@ network:
   _target_: text_recognizer.networks.conv_transformer.ConvTransformer
   input_dims: [1, 576, 640]
   hidden_dim: &hidden_dim 192
-  encoder_dim: 1280
-  dropout_rate: 0.05
   num_classes: *num_classes
   pad_index: *ignore_index
   encoder:
@@ -99,7 +97,7 @@ network:
     bn_momentum: 0.99
     bn_eps: 1.0e-3
   decoder:
-    depth: 4
+    depth: 3
     local_depth: 2
     _target_: text_recognizer.networks.transformer.layers.Decoder
     self_attn:
@@ -114,8 +112,9 @@ network:
     local_self_attn:
       _target_: text_recognizer.networks.transformer.local_attention.LocalAttention
       << : *attn
-      window_size: 11
-      look_back: 2
+      window_size: 31
+      look_back: 1
+      autopad: true
       << : *rotary_embedding
     norm:
       _target_: text_recognizer.networks.transformer.norm.ScaleNorm
-- 
cgit v1.2.3-70-g09d2
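
One non-obvious detail in both configs: the lr_schedulers.network node mixes
Lightning bookkeeping keys (interval, monitor) into the OneCycleLR arguments.
A minimal sketch of the split a LightningModule has to perform before
instantiation (the helper name is an assumption; the repo's actual wiring may
differ):

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    def build_scheduler_dict(cfg, optimizer):
        # interval and monitor configure PyTorch Lightning's scheduler dict;
        # they are not OneCycleLR kwargs and must not reach the constructor.
        kwargs = OmegaConf.to_container(cfg, resolve=True)
        interval = kwargs.pop("interval", "epoch")
        monitor = kwargs.pop("monitor", "val/loss")
        scheduler = instantiate(kwargs, optimizer=optimizer)
        return {"scheduler": scheduler, "interval": interval, "monitor": monitor}

The returned dict is what configure_optimizers expects for the scheduler
entry; with interval: step the scheduler advances once per batch, which is
what OneCycleLR's epochs * steps_per_epoch step budget assumes.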