From 5a9448c0a975b3e99b5456a9d991b8232b019aed Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Fri, 11 Feb 2022 22:49:09 +0100
Subject: feat: update lines experiment

---
 .../conf/experiment/conv_transformer_lines.yaml    | 108 ++++++++++++---------
 1 file changed, 60 insertions(+), 48 deletions(-)

(limited to 'training/conf')

diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml
index cca452d..259e4ea 100644
--- a/training/conf/experiment/conv_transformer_lines.yaml
+++ b/training/conf/experiment/conv_transformer_lines.yaml
@@ -1,3 +1,4 @@
+---
 # @package _global_
 
 defaults:
@@ -18,8 +19,8 @@ summary: [[1, 1, 56, 1024], [1, 89]]
 
 criterion:
   ignore_index: *ignore_index
-  label_smoothing: 0.1
-    
+  # label_smoothing: 0.1
+
 mapping: &mapping
   mapping:
     _target_: text_recognizer.data.mappings.emnist.EmnistMapping
@@ -34,20 +35,26 @@ callbacks:
     device: null
 
 optimizers:
-  madgrad:
-    _target_: madgrad.MADGRAD
-    lr: 1.0e-4
-    momentum: 0.9
+  radam:
+    _target_: torch.optim.RAdam
+    lr: 3.0e-4
+    betas: [0.9, 0.999]
     weight_decay: 0
-    eps: 1.0e-6
+    eps: 1.0e-8
     parameters: network
 
 lr_schedulers:
   network:
-    _target_: torch.optim.lr_scheduler.CosineAnnealingLR
-    T_max: *epochs
-    eta_min: 1.0e-6
-    last_epoch: -1
+    _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+    mode: min
+    factor: 0.5
+    patience: 10
+    threshold: 1.0e-4
+    threshold_mode: rel
+    cooldown: 0
+    min_lr: 1.0e-5
+    eps: 1.0e-8
+    verbose: false
     interval: epoch
     monitor: val/loss
 
@@ -58,33 +65,32 @@ datamodule:
   pin_memory: true
   << : *mapping
 
+encoder: &encoder
+  _target_: text_recognizer.networks.efficientnet.efficientnet.EfficientNet
+  arch: b0
+  stochastic_dropout_rate: 0.2
+  bn_momentum: 0.99
+  bn_eps: 1.0e-3
+  depth: 5
+
 rotary_embedding: &rotary_embedding
-  rotary_embedding: 
-    _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding
+  rotary_embedding:
+    _target_: >
+      text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding
     dim: 64
 
 attn: &attn
-  dim: &hidden_dim 128
+  dim: &hidden_dim 512
   num_heads: 4
   dim_head: 64
-  dropout_rate: &dropout_rate 0.2
+  dropout_rate: &dropout_rate 0.4
 
-network:
-  _target_: text_recognizer.networks.conv_transformer.ConvTransformer
-  input_dims: [1, 56, 1024]
-  hidden_dim: *hidden_dim
-  num_classes: *num_classes
-  pad_index: *ignore_index
-  encoder:
-    _target_: text_recognizer.networks.efficientnet.efficientnet.EfficientNet
-    arch: b0
-    depth: 5
-    stochastic_dropout_rate: 0.2
-    bn_momentum: 0.99
-    bn_eps: 1.0e-3
-  decoder:
-    depth: 3
-    _target_: text_recognizer.networks.transformer.layers.Decoder
+decoder: &decoder
+  _target_: text_recognizer.networks.transformer.decoder.Decoder
+  depth: 6
+  has_pos_emb: true
+  block:
+    _target_: text_recognizer.networks.transformer.decoder.DecoderBlock
     self_attn:
       _target_: text_recognizer.networks.transformer.attention.Attention
       << : *attn
@@ -95,28 +101,34 @@ network:
       << : *attn
       causal: false
     norm:
-      _target_: text_recognizer.networks.transformer.norm.ScaleNorm
-      normalized_shape: *hidden_dim
-    ff: 
+      _target_: text_recognizer.networks.transformer.norm.RMSNorm
+      dim: *hidden_dim
+    ff:
       _target_: text_recognizer.networks.transformer.mlp.FeedForward
       dim: *hidden_dim
       dim_out: null
-      expansion_factor: 4
+      expansion_factor: 2
       glu: true
       dropout_rate: *dropout_rate
-    pre_norm: true
+
+pixel_pos_embedding: &pixel_pos_embedding
+  _target_: >
+    text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbedding
+  dim: *hidden_dim
+  shape: &shape [3, 64]
+
+network:
+  _target_: text_recognizer.networks.conv_transformer.ConvTransformer
+  input_dims: [1, 1, 56, 1024]
+  hidden_dim: *hidden_dim
+  num_classes: *num_classes
+  pad_index: *ignore_index
+  encoder:
+    << : *encoder
+  decoder:
+    << : *decoder
   pixel_pos_embedding:
-    _target_: text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbedding
-    dim: *hidden_dim
-    shape: &shape [3, 64]
-  axial_encoder: null
-    # _target_: text_recognizer.networks.transformer.axial_attention.encoder.AxialEncoder
-    # dim: *hidden_dim
-    # heads: 4
-    # shape: *shape
-    # depth: 2
-    # dim_head: 64
-    # dim_index: 1
+    << : *pixel_pos_embedding
 
 model:
   _target_: text_recognizer.models.transformer.TransformerLitModel
@@ -138,7 +150,7 @@ trainer:
   max_epochs: *epochs
   terminate_on_nan: true
   weights_summary: null
-  limit_train_batches: 1.0 
+  limit_train_batches: 1.0
   limit_val_batches: 1.0
   limit_test_batches: 1.0
   resume_from_checkpoint: null
-- 
cgit v1.2.3-70-g09d2