From 70540bf897df1d60375ea220cfab838cbd28c47f Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Fri, 5 Nov 2021 19:27:50 +0100
Subject: Update lines config

Switch the criterion from label smoothing to plain cross-entropy,
replace the OneCycleLR schedule with per-epoch cosine annealing, raise
the MADGRAD learning rate to 3e-4 with no weight decay, and scale the
network up to an EfficientNet-B2 encoder with a 256-dim, 6-layer,
8-head decoder at 0.5 dropout. Halve the epoch budget to 256 and trim
num_classes to 57.

---
 .../conf/experiment/conv_transformer_lines.yaml    | 51 ++++++++--------------
 1 file changed, 18 insertions(+), 33 deletions(-)

diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml
index d2a666f..6ba4535 100644
--- a/training/conf/experiment/conv_transformer_lines.yaml
+++ b/training/conf/experiment/conv_transformer_lines.yaml
@@ -2,7 +2,7 @@
 
 defaults:
   - override /mapping: null
-  - override /criterion: null
+  - override /criterion: cross_entropy
   - override /callbacks: htr
   - override /datamodule: iam_lines
   - override /network: null
@@ -10,21 +10,18 @@ defaults:
   - override /lr_schedulers: null
   - override /optimizers: null
 
-epochs: &epochs 512
+epochs: &epochs 256
 ignore_index: &ignore_index 3
-num_classes: &num_classes 58
+num_classes: &num_classes 57
 max_output_len: &max_output_len 89
 summary: [[1, 1, 56, 1024], [1, 89]]
 
 criterion:
-  _target_: text_recognizer.criterion.label_smoothing.LabelSmoothingLoss
-  smoothing: 0.1 
   ignore_index: *ignore_index
     
 mapping: &mapping
   mapping:
     _target_: text_recognizer.data.mappings.emnist.EmnistMapping
-    # extra_symbols: [ "\n" ]
 
 callbacks:
   stochastic_weight_averaging:
@@ -38,31 +35,20 @@ callbacks:
 optimizers:
   madgrad:
     _target_: madgrad.MADGRAD
-    lr: 1.0e-4
+    lr: 3.0e-4
     momentum: 0.9
-    weight_decay: 5.0e-6
+    weight_decay: 0
     eps: 1.0e-6
     parameters: network
 
 lr_schedulers:
   network:
-    _target_: torch.optim.lr_scheduler.OneCycleLR
-    max_lr: 1.0e-4
-    total_steps: null
-    epochs: *epochs
-    steps_per_epoch: 722
-    pct_start: 0.01
-    anneal_strategy: cos
-    cycle_momentum: true
-    base_momentum: 0.85
-    max_momentum: 0.95
-    div_factor: 25
-    final_div_factor: 1.0e2
-    three_phase: false
-    last_epoch: -1
-    verbose: false
-    interval: step
-    monitor: val/loss
+    _target_: torch.optim.lr_scheduler.CosineAnnealingLR
+    T_max: 256
+    eta_min: 1.0e-5
+    last_epoch: -1
+    interval: epoch
+    monitor: val/loss
 
 datamodule:
   batch_size: 32
@@ -77,26 +63,25 @@ rotary_embedding: &rotary_embedding
     dim: 64
 
 attn: &attn
-  dim: 192
-  num_heads: 4
+  dim: &hidden_dim 256
+  num_heads: 8
   dim_head: 64
-  dropout_rate: 0.05
+  dropout_rate: &dropout_rate 0.5
 
 network:
   _target_: text_recognizer.networks.conv_transformer.ConvTransformer
   input_dims: [1, 56, 1024]
-  hidden_dim: &hidden_dim 192
+  hidden_dim: *hidden_dim
   num_classes: *num_classes
   pad_index: *ignore_index
   encoder:
     _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet
-    arch: b0
-    out_channels: 1280
+    arch: b2
     stochastic_dropout_rate: 0.2
     bn_momentum: 0.99
     bn_eps: 1.0e-3
   decoder:
-    depth: 4
+    depth: 6
     _target_: text_recognizer.networks.transformer.layers.Decoder
     self_attn:
       _target_: text_recognizer.networks.transformer.attention.Attention
@@ -116,7 +101,7 @@ network:
       dim_out: null
       expansion_factor: 4
       glu: true
-      dropout_rate: 0.05
+      dropout_rate: *dropout_rate
     pre_norm: true
   pixel_pos_embedding:
     _target_: text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbedding
-- 
cgit v1.2.3-70-g09d2
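
Note on the &hidden_dim and &dropout_rate anchors introduced above: a
YAML anchor ("&") names a value where it is first set, and an alias
("*") reuses it, so the attention width and the network's hidden_dim
cannot drift apart. Anchors are resolved by the YAML parser itself,
before Hydra composes the config. A minimal PyYAML sketch of the
mechanism, reduced to the anchored keys from this patch:

    import yaml

    # Snippet trimmed to the anchored keys; the real config has more fields.
    doc = """
    attn: &attn
      dim: &hidden_dim 256
      num_heads: 8
      dropout_rate: &dropout_rate 0.5

    network:
      hidden_dim: *hidden_dim
    """

    cfg = yaml.safe_load(doc)
    # The alias resolves to the anchored value: both are 256.
    assert cfg["network"]["hidden_dim"] == cfg["attn"]["dim"] == 256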
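
The scheduler swap is the core training change: OneCycleLR, stepped per
batch (interval: step), gives way to CosineAnnealingLR stepped once per
epoch (interval: epoch), paired with a higher MADGRAD learning rate and
no weight decay. A minimal PyTorch sketch of the resulting setup,
assuming the madgrad package and a stand-in module in place of the real
ConvTransformer (the interval and monitor keys are consumed by the
training framework, not by the scheduler itself):

    import torch
    from madgrad import MADGRAD

    model = torch.nn.Linear(8, 8)  # stand-in for the ConvTransformer

    # Optimizer settings from this patch: lr 3e-4, momentum 0.9, no weight decay.
    optimizer = MADGRAD(
        model.parameters(), lr=3.0e-4, momentum=0.9, weight_decay=0, eps=1.0e-6
    )

    # One half-cosine from 3e-4 down to eta_min=1e-5 across the 256-epoch run.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=256, eta_min=1.0e-5, last_epoch=-1
    )

    # The /criterion: cross_entropy override plus ignore_index: 3 amounts to a
    # cross-entropy loss that skips the pad index.
    criterion = torch.nn.CrossEntropyLoss(ignore_index=3)

    for epoch in range(256):
        # ... one pass over the IAM lines batches ...
        scheduler.step()  # per-epoch stepping, matching interval: epoch

With T_max equal to the epoch budget, the learning rate traces a single
half-cosine over the whole run, which also drops OneCycleLR's dependence
on a hard-coded steps_per_epoch (722 in the old config).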