From a65d3ec18a5541cec5297769f1027422975a62bc Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Sun, 3 Sep 2023 01:13:37 +0200
Subject: Update confs and callbacks

---
 .../conf/experiment/conv_transformer_lines.yaml    | 134 ---------------------
 .../experiment/conv_transformer_paragraphs.yaml    | 132 --------------------
 training/conf/experiment/vit_lines.yaml            |  45 +------
 3 files changed, 3 insertions(+), 308 deletions(-)
 delete mode 100644 training/conf/experiment/conv_transformer_lines.yaml
 delete mode 100644 training/conf/experiment/conv_transformer_paragraphs.yaml

(limited to 'training/conf/experiment')

diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml
deleted file mode 100644
index 12fe701..0000000
--- a/training/conf/experiment/conv_transformer_lines.yaml
+++ /dev/null
@@ -1,134 +0,0 @@
-# @package _global_
-
-defaults:
-  - override /criterion: cross_entropy
-  - override /callbacks: htr
-  - override /datamodule: iam_lines
-  - override /network: null
-  - override /model: lit_transformer
-  - override /lr_scheduler: null
-  - override /optimizer: null
-
-tags: [lines]
-epochs: &epochs 64
-ignore_index: &ignore_index 3
-# summary: [[1, 1, 56, 1024], [1, 89]]
-
-logger:
-  wandb:
-    tags: ${tags}
-
-criterion:
-  ignore_index: *ignore_index
-  # label_smoothing: 0.05
-
-callbacks:
-  stochastic_weight_averaging:
-    _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
-    swa_epoch_start: 0.75
-    swa_lrs: 1.0e-5
-    annealing_epochs: 10
-    annealing_strategy: cos
-    device: null
-
-optimizer:
-  _target_: adan_pytorch.Adan
-  lr: 3.0e-4
-  betas: [0.02, 0.08, 0.01]
-  weight_decay: 0.02
-
-lr_scheduler:
-  _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
-  mode: min
-  factor: 0.8
-  patience: 10
-  threshold: 1.0e-4
-  threshold_mode: rel
-  cooldown: 0
-  min_lr: 1.0e-5
-  eps: 1.0e-8
-  verbose: false
-  interval: epoch
-  monitor: val/cer
-
-datamodule:
-  batch_size: 16
-  train_fraction: 0.95
-
-network:
-  _target_: text_recognizer.network.ConvTransformer
-  encoder:
-    _target_: text_recognizer.network.image_encoder.ImageEncoder
-    encoder:
-      _target_: text_recognizer.network.convnext.ConvNext
-      dim: 16
-      dim_mults: [2, 4, 32]
-      depths: [3, 3, 6]
-      downsampling_factors: [[2, 2], [2, 2], [2, 2]]
-      attn:
-        _target_: text_recognizer.network.convnext.TransformerBlock
-        attn:
-          _target_: text_recognizer.network.convnext.Attention
-          dim: &dim 512
-          heads: 4
-          dim_head: 64
-          scale: 8
-        ff:
-          _target_: text_recognizer.network.convnext.FeedForward
-          dim: *dim
-          mult: 2
-    pixel_embedding:
-      _target_: "text_recognizer.network.transformer.embeddings.axial.\
-        AxialPositionalEmbeddingImage"
-      dim: *dim
-      axial_shape: [7, 128]
-  decoder:
-    _target_: text_recognizer.network.text_decoder.TextDecoder
-    dim: *dim
-    num_classes: 58
-    pad_index: *ignore_index
-    decoder:
-      _target_: text_recognizer.network.transformer.Decoder
-      dim: *dim
-      depth: 6
-      block:
-        _target_: "text_recognizer.network.transformer.decoder_block.\
-          DecoderBlock"
-        self_attn:
-          _target_: text_recognizer.network.transformer.Attention
-          dim: *dim
-          num_heads: 8
-          dim_head: &dim_head 64
-          dropout_rate: &dropout_rate 0.2
-          causal: true
-        cross_attn:
-          _target_: text_recognizer.network.transformer.Attention
-          dim: *dim
-          num_heads: 8
-          dim_head: *dim_head
-          dropout_rate: *dropout_rate
-          causal: false
-        norm:
-          _target_: text_recognizer.network.transformer.RMSNorm
-          dim: *dim
-        ff:
-          _target_: text_recognizer.network.transformer.FeedForward
-          dim: *dim
-          dim_out: null
-          expansion_factor: 2
-          glu: true
-          dropout_rate: *dropout_rate
-      rotary_embedding:
-        _target_: text_recognizer.network.transformer.RotaryEmbedding
-        dim: *dim_head
-
-model:
-  max_output_len: 89
-
-trainer:
-  gradient_clip_val: 1.0
-  max_epochs: *epochs
-  accumulate_grad_batches: 1
-  limit_train_batches: 1.0
-  limit_val_batches: 1.0
-  limit_test_batches: 1.0
diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml
deleted file mode 100644
index 9df2ea9..0000000
--- a/training/conf/experiment/conv_transformer_paragraphs.yaml
+++ /dev/null
@@ -1,132 +0,0 @@
-# @package _global_
-
-defaults:
-  - override /criterion: cross_entropy
-  - override /callbacks: htr
-  - override /datamodule: iam_extended_paragraphs
-  - override /network: null
-  - override /model: lit_transformer
-  - override /lr_scheduler: null
-  - override /optimizer: null
-
-tags: [paragraphs]
-epochs: &epochs 256
-ignore_index: &ignore_index 3
-# max_output_len: &max_output_len 682
-# summary: [[1, 1, 576, 640], [1, 682]]
-
-logger:
-  wandb:
-    tags: ${tags}
-
-criterion:
-  ignore_index: *ignore_index
-  # label_smoothing: 0.05
-
-callbacks:
-  stochastic_weight_averaging:
-    _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
-    swa_epoch_start: 0.75
-    swa_lrs: 1.0e-5
-    annealing_epochs: 10
-    annealing_strategy: cos
-    device: null
-
-optimizer:
-  _target_: adan_pytorch.Adan
-  lr: 3.0e-4
-  betas: [0.02, 0.08, 0.01]
-  weight_decay: 0.02
-
-lr_scheduler:
-  _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
-  mode: min
-  factor: 0.8
-  patience: 10
-  threshold: 1.0e-4
-  threshold_mode: rel
-  cooldown: 0
-  min_lr: 1.0e-5
-  eps: 1.0e-8
-  verbose: false
-  interval: epoch
-  monitor: val/cer
-
-datamodule:
-  batch_size: 4
-  train_fraction: 0.95
-
-network:
-  _target_: text_recognizer.network.ConvTransformer
-  encoder:
-    _target_: text_recognizer.network.image_encoder.ImageEncoder
-    encoder:
-      _target_: text_recognizer.network.convnext.ConvNext
-      dim: 16
-      dim_mults: [1, 2, 4, 8, 32]
-      depths: [2, 3, 3, 3, 6]
-      downsampling_factors: [[2, 2], [2, 2], [2, 2], [2, 1], [2, 1]]
-      attn:
-        _target_: text_recognizer.network.convnext.TransformerBlock
-        attn:
-          _target_: text_recognizer.network.convnext.Attention
-          dim: &dim 512
-          heads: 4
-          dim_head: 64
-          scale: 8
-        ff:
-          _target_: text_recognizer.network.convnext.FeedForward
-          dim: *dim
-          mult: 2
-    pixel_embedding:
-      _target_: "text_recognizer.network.transformer.embeddings.axial.\
-        AxialPositionalEmbeddingImage"
-      dim: *dim
-      axial_shape: [18, 80]
-  decoder:
-    _target_: text_recognizer.network.text_decoder.TextDecoder
-    dim: *dim
-    num_classes: 58
-    pad_index: *ignore_index
-    decoder:
-      _target_: text_recognizer.network.transformer.Decoder
-      dim: *dim
-      depth: 6
-      block:
-        _target_: "text_recognizer.network.transformer.decoder_block.\
-          DecoderBlock"
-        self_attn:
-          _target_: text_recognizer.network.transformer.Attention
-          dim: *dim
-          num_heads: 8
-          dim_head: &dim_head 64
-          dropout_rate: &dropout_rate 0.2
-          causal: true
-        cross_attn:
-          _target_: text_recognizer.network.transformer.Attention
-          dim: *dim
-          num_heads: 8
-          dim_head: *dim_head
-          dropout_rate: *dropout_rate
-          causal: false
-        norm:
-          _target_: text_recognizer.network.transformer.RMSNorm
-          dim: *dim
-        ff:
-          _target_: text_recognizer.network.transformer.FeedForward
-          dim: *dim
-          dim_out: null
-          expansion_factor: 2
-          glu: true
-          dropout_rate: *dropout_rate
-      rotary_embedding:
-        _target_: text_recognizer.network.transformer.RotaryEmbedding
-        dim: *dim_head
-
-trainer:
-  gradient_clip_val: 1.0
-  max_epochs: *epochs
-  accumulate_grad_batches: 2
-  limit_train_batches: 1.0
-  limit_val_batches: 1.0
-  limit_test_batches: 1.0
diff --git a/training/conf/experiment/vit_lines.yaml b/training/conf/experiment/vit_lines.yaml
index f57eead..f3049ea 100644
--- a/training/conf/experiment/vit_lines.yaml
+++ b/training/conf/experiment/vit_lines.yaml
@@ -4,13 +4,13 @@ defaults:
   - override /criterion: cross_entropy
   - override /callbacks: htr
   - override /datamodule: iam_lines
-  - override /network: null
+  - override /network: vit_lines
   - override /model: lit_transformer
   - override /lr_scheduler: null
   - override /optimizer: null
 
 tags: [lines, vit]
-epochs: &epochs 256
+epochs: &epochs 128
 ignore_index: &ignore_index 3
 # summary: [[1, 1, 56, 1024], [1, 89]]
 
@@ -59,45 +59,6 @@ datamodule:
   batch_size: 16
   train_fraction: 0.95
 
-network:
-  _target_: text_recognizer.network.vit.VisionTransformer
-  image_height: 56
-  image_width: 1024
-  patch_height: 28
-  patch_width: 32
-  dim: &dim 1024
-  num_classes: &num_classes 58
-  encoder:
-    _target_: text_recognizer.network.transformer.encoder.Encoder
-    dim: *dim
-    inner_dim: 2048
-    heads: 16
-    dim_head: 64
-    depth: 4
-    dropout_rate: 0.0
-  decoder:
-    _target_: text_recognizer.network.transformer.decoder.Decoder
-    dim: *dim
-    inner_dim: 2048
-    heads: 16
-    dim_head: 64
-    depth: 4
-    dropout_rate: 0.0
-  token_embedding:
-    _target_: "text_recognizer.network.transformer.embedding.token.\
-      TokenEmbedding"
-    num_tokens: *num_classes
-    dim: *dim
-    use_l2: true
-  pos_embedding:
-    _target_: "text_recognizer.network.transformer.embedding.absolute.\
-      AbsolutePositionalEmbedding"
-    dim: *dim
-    max_length: 89
-    use_l2: true
-  tie_embeddings: true
-  pad_index: 3
-
 model:
   max_output_len: 89
 
@@ -105,7 +66,7 @@ trainer:
   fast_dev_run: false
   gradient_clip_val: 1.0
   max_epochs: *epochs
-  accumulate_grad_batches: 4
+  accumulate_grad_batches: 1
   limit_train_batches: 1.0
   limit_val_batches: 1.0
   limit_test_batches: 1.0
-- 
cgit v1.2.3-70-g09d2