summaryrefslogtreecommitdiff
path: root/training/conf/experiment
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2023-09-03 01:13:37 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2023-09-03 01:13:37 +0200
commita65d3ec18a5541cec5297769f1027422975a62bc (patch)
tree08e5e22f76db2449d265476f5fb42c5ea64a2776 /training/conf/experiment
parente4d618443808f0931bbef0b9e10a2c2a215281a5 (diff)
Update confs and callbacks
Diffstat (limited to 'training/conf/experiment')
-rw-r--r--training/conf/experiment/conv_transformer_lines.yaml134
-rw-r--r--training/conf/experiment/conv_transformer_paragraphs.yaml132
-rw-r--r--training/conf/experiment/vit_lines.yaml45
3 files changed, 3 insertions, 308 deletions
diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml
deleted file mode 100644
index 12fe701..0000000
--- a/training/conf/experiment/conv_transformer_lines.yaml
+++ /dev/null
@@ -1,134 +0,0 @@
-# @package _global_
-
-defaults:
- - override /criterion: cross_entropy
- - override /callbacks: htr
- - override /datamodule: iam_lines
- - override /network: null
- - override /model: lit_transformer
- - override /lr_scheduler: null
- - override /optimizer: null
-
-tags: [lines]
-epochs: &epochs 64
-ignore_index: &ignore_index 3
-# summary: [[1, 1, 56, 1024], [1, 89]]
-
-logger:
- wandb:
- tags: ${tags}
-
-criterion:
- ignore_index: *ignore_index
- # label_smoothing: 0.05
-
-callbacks:
- stochastic_weight_averaging:
- _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
- swa_epoch_start: 0.75
- swa_lrs: 1.0e-5
- annealing_epochs: 10
- annealing_strategy: cos
- device: null
-
-optimizer:
- _target_: adan_pytorch.Adan
- lr: 3.0e-4
- betas: [0.02, 0.08, 0.01]
- weight_decay: 0.02
-
-lr_scheduler:
- _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
- mode: min
- factor: 0.8
- patience: 10
- threshold: 1.0e-4
- threshold_mode: rel
- cooldown: 0
- min_lr: 1.0e-5
- eps: 1.0e-8
- verbose: false
- interval: epoch
- monitor: val/cer
-
-datamodule:
- batch_size: 16
- train_fraction: 0.95
-
-network:
- _target_: text_recognizer.network.ConvTransformer
- encoder:
- _target_: text_recognizer.network.image_encoder.ImageEncoder
- encoder:
- _target_: text_recognizer.network.convnext.ConvNext
- dim: 16
- dim_mults: [2, 4, 32]
- depths: [3, 3, 6]
- downsampling_factors: [[2, 2], [2, 2], [2, 2]]
- attn:
- _target_: text_recognizer.network.convnext.TransformerBlock
- attn:
- _target_: text_recognizer.network.convnext.Attention
- dim: &dim 512
- heads: 4
- dim_head: 64
- scale: 8
- ff:
- _target_: text_recognizer.network.convnext.FeedForward
- dim: *dim
- mult: 2
- pixel_embedding:
- _target_: "text_recognizer.network.transformer.embeddings.axial.\
- AxialPositionalEmbeddingImage"
- dim: *dim
- axial_shape: [7, 128]
- decoder:
- _target_: text_recognizer.network.text_decoder.TextDecoder
- dim: *dim
- num_classes: 58
- pad_index: *ignore_index
- decoder:
- _target_: text_recognizer.network.transformer.Decoder
- dim: *dim
- depth: 6
- block:
- _target_: "text_recognizer.network.transformer.decoder_block.\
- DecoderBlock"
- self_attn:
- _target_: text_recognizer.network.transformer.Attention
- dim: *dim
- num_heads: 8
- dim_head: &dim_head 64
- dropout_rate: &dropout_rate 0.2
- causal: true
- cross_attn:
- _target_: text_recognizer.network.transformer.Attention
- dim: *dim
- num_heads: 8
- dim_head: *dim_head
- dropout_rate: *dropout_rate
- causal: false
- norm:
- _target_: text_recognizer.network.transformer.RMSNorm
- dim: *dim
- ff:
- _target_: text_recognizer.network.transformer.FeedForward
- dim: *dim
- dim_out: null
- expansion_factor: 2
- glu: true
- dropout_rate: *dropout_rate
- rotary_embedding:
- _target_: text_recognizer.network.transformer.RotaryEmbedding
- dim: *dim_head
-
-model:
- max_output_len: 89
-
-trainer:
- gradient_clip_val: 1.0
- max_epochs: *epochs
- accumulate_grad_batches: 1
- limit_train_batches: 1.0
- limit_val_batches: 1.0
- limit_test_batches: 1.0
diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml
deleted file mode 100644
index 9df2ea9..0000000
--- a/training/conf/experiment/conv_transformer_paragraphs.yaml
+++ /dev/null
@@ -1,132 +0,0 @@
-# @package _global_
-
-defaults:
- - override /criterion: cross_entropy
- - override /callbacks: htr
- - override /datamodule: iam_extended_paragraphs
- - override /network: null
- - override /model: lit_transformer
- - override /lr_scheduler: null
- - override /optimizer: null
-
-tags: [paragraphs]
-epochs: &epochs 256
-ignore_index: &ignore_index 3
-# max_output_len: &max_output_len 682
-# summary: [[1, 1, 576, 640], [1, 682]]
-
-logger:
- wandb:
- tags: ${tags}
-
-criterion:
- ignore_index: *ignore_index
- # label_smoothing: 0.05
-
-callbacks:
- stochastic_weight_averaging:
- _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
- swa_epoch_start: 0.75
- swa_lrs: 1.0e-5
- annealing_epochs: 10
- annealing_strategy: cos
- device: null
-
-optimizer:
- _target_: adan_pytorch.Adan
- lr: 3.0e-4
- betas: [0.02, 0.08, 0.01]
- weight_decay: 0.02
-
-lr_scheduler:
- _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
- mode: min
- factor: 0.8
- patience: 10
- threshold: 1.0e-4
- threshold_mode: rel
- cooldown: 0
- min_lr: 1.0e-5
- eps: 1.0e-8
- verbose: false
- interval: epoch
- monitor: val/cer
-
-datamodule:
- batch_size: 4
- train_fraction: 0.95
-
-network:
- _target_: text_recognizer.network.ConvTransformer
- encoder:
- _target_: text_recognizer.network.image_encoder.ImageEncoder
- encoder:
- _target_: text_recognizer.network.convnext.ConvNext
- dim: 16
- dim_mults: [1, 2, 4, 8, 32]
- depths: [2, 3, 3, 3, 6]
- downsampling_factors: [[2, 2], [2, 2], [2, 2], [2, 1], [2, 1]]
- attn:
- _target_: text_recognizer.network.convnext.TransformerBlock
- attn:
- _target_: text_recognizer.network.convnext.Attention
- dim: &dim 512
- heads: 4
- dim_head: 64
- scale: 8
- ff:
- _target_: text_recognizer.network.convnext.FeedForward
- dim: *dim
- mult: 2
- pixel_embedding:
- _target_: "text_recognizer.network.transformer.embeddings.axial.\
- AxialPositionalEmbeddingImage"
- dim: *dim
- axial_shape: [18, 80]
- decoder:
- _target_: text_recognizer.network.text_decoder.TextDecoder
- dim: *dim
- num_classes: 58
- pad_index: *ignore_index
- decoder:
- _target_: text_recognizer.network.transformer.Decoder
- dim: *dim
- depth: 6
- block:
- _target_: "text_recognizer.network.transformer.decoder_block.\
- DecoderBlock"
- self_attn:
- _target_: text_recognizer.network.transformer.Attention
- dim: *dim
- num_heads: 8
- dim_head: &dim_head 64
- dropout_rate: &dropout_rate 0.2
- causal: true
- cross_attn:
- _target_: text_recognizer.network.transformer.Attention
- dim: *dim
- num_heads: 8
- dim_head: *dim_head
- dropout_rate: *dropout_rate
- causal: false
- norm:
- _target_: text_recognizer.network.transformer.RMSNorm
- dim: *dim
- ff:
- _target_: text_recognizer.network.transformer.FeedForward
- dim: *dim
- dim_out: null
- expansion_factor: 2
- glu: true
- dropout_rate: *dropout_rate
- rotary_embedding:
- _target_: text_recognizer.network.transformer.RotaryEmbedding
- dim: *dim_head
-
-trainer:
- gradient_clip_val: 1.0
- max_epochs: *epochs
- accumulate_grad_batches: 2
- limit_train_batches: 1.0
- limit_val_batches: 1.0
- limit_test_batches: 1.0
diff --git a/training/conf/experiment/vit_lines.yaml b/training/conf/experiment/vit_lines.yaml
index f57eead..f3049ea 100644
--- a/training/conf/experiment/vit_lines.yaml
+++ b/training/conf/experiment/vit_lines.yaml
@@ -4,13 +4,13 @@ defaults:
- override /criterion: cross_entropy
- override /callbacks: htr
- override /datamodule: iam_lines
- - override /network: null
+ - override /network: vit_lines
- override /model: lit_transformer
- override /lr_scheduler: null
- override /optimizer: null
tags: [lines, vit]
-epochs: &epochs 256
+epochs: &epochs 128
ignore_index: &ignore_index 3
# summary: [[1, 1, 56, 1024], [1, 89]]
@@ -59,45 +59,6 @@ datamodule:
batch_size: 16
train_fraction: 0.95
-network:
- _target_: text_recognizer.network.vit.VisionTransformer
- image_height: 56
- image_width: 1024
- patch_height: 28
- patch_width: 32
- dim: &dim 1024
- num_classes: &num_classes 58
- encoder:
- _target_: text_recognizer.network.transformer.encoder.Encoder
- dim: *dim
- inner_dim: 2048
- heads: 16
- dim_head: 64
- depth: 4
- dropout_rate: 0.0
- decoder:
- _target_: text_recognizer.network.transformer.decoder.Decoder
- dim: *dim
- inner_dim: 2048
- heads: 16
- dim_head: 64
- depth: 4
- dropout_rate: 0.0
- token_embedding:
- _target_: "text_recognizer.network.transformer.embedding.token.\
- TokenEmbedding"
- num_tokens: *num_classes
- dim: *dim
- use_l2: true
- pos_embedding:
- _target_: "text_recognizer.network.transformer.embedding.absolute.\
- AbsolutePositionalEmbedding"
- dim: *dim
- max_length: 89
- use_l2: true
- tie_embeddings: true
- pad_index: 3
-
model:
max_output_len: 89
@@ -105,7 +66,7 @@ trainer:
fast_dev_run: false
gradient_clip_val: 1.0
max_epochs: *epochs
- accumulate_grad_batches: 4
+ accumulate_grad_batches: 1
limit_train_batches: 1.0
limit_val_batches: 1.0
limit_test_batches: 1.0