summaryrefslogtreecommitdiff
path: root/training/conf/experiment/conv_transformer_paragraphs.yaml
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2023-09-03 01:13:37 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2023-09-03 01:13:37 +0200
commita65d3ec18a5541cec5297769f1027422975a62bc (patch)
tree08e5e22f76db2449d265476f5fb42c5ea64a2776 /training/conf/experiment/conv_transformer_paragraphs.yaml
parente4d618443808f0931bbef0b9e10a2c2a215281a5 (diff)
Update confs and callbacks
Diffstat (limited to 'training/conf/experiment/conv_transformer_paragraphs.yaml')
-rw-r--r--training/conf/experiment/conv_transformer_paragraphs.yaml132
1 files changed, 0 insertions, 132 deletions
diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml
deleted file mode 100644
index 9df2ea9..0000000
--- a/training/conf/experiment/conv_transformer_paragraphs.yaml
+++ /dev/null
@@ -1,132 +0,0 @@
-# @package _global_
-
-defaults:
- - override /criterion: cross_entropy
- - override /callbacks: htr
- - override /datamodule: iam_extended_paragraphs
- - override /network: null
- - override /model: lit_transformer
- - override /lr_scheduler: null
- - override /optimizer: null
-
-tags: [paragraphs]
-epochs: &epochs 256
-ignore_index: &ignore_index 3
-# max_output_len: &max_output_len 682
-# summary: [[1, 1, 576, 640], [1, 682]]
-
-logger:
- wandb:
- tags: ${tags}
-
-criterion:
- ignore_index: *ignore_index
- # label_smoothing: 0.05
-
-callbacks:
- stochastic_weight_averaging:
- _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
- swa_epoch_start: 0.75
- swa_lrs: 1.0e-5
- annealing_epochs: 10
- annealing_strategy: cos
- device: null
-
-optimizer:
- _target_: adan_pytorch.Adan
- lr: 3.0e-4
- betas: [0.02, 0.08, 0.01]
- weight_decay: 0.02
-
-lr_scheduler:
- _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
- mode: min
- factor: 0.8
- patience: 10
- threshold: 1.0e-4
- threshold_mode: rel
- cooldown: 0
- min_lr: 1.0e-5
- eps: 1.0e-8
- verbose: false
- interval: epoch
- monitor: val/cer
-
-datamodule:
- batch_size: 4
- train_fraction: 0.95
-
-network:
- _target_: text_recognizer.network.ConvTransformer
- encoder:
- _target_: text_recognizer.network.image_encoder.ImageEncoder
- encoder:
- _target_: text_recognizer.network.convnext.ConvNext
- dim: 16
- dim_mults: [1, 2, 4, 8, 32]
- depths: [2, 3, 3, 3, 6]
- downsampling_factors: [[2, 2], [2, 2], [2, 2], [2, 1], [2, 1]]
- attn:
- _target_: text_recognizer.network.convnext.TransformerBlock
- attn:
- _target_: text_recognizer.network.convnext.Attention
- dim: &dim 512
- heads: 4
- dim_head: 64
- scale: 8
- ff:
- _target_: text_recognizer.network.convnext.FeedForward
- dim: *dim
- mult: 2
- pixel_embedding:
- _target_: "text_recognizer.network.transformer.embeddings.axial.\
- AxialPositionalEmbeddingImage"
- dim: *dim
- axial_shape: [18, 80]
- decoder:
- _target_: text_recognizer.network.text_decoder.TextDecoder
- dim: *dim
- num_classes: 58
- pad_index: *ignore_index
- decoder:
- _target_: text_recognizer.network.transformer.Decoder
- dim: *dim
- depth: 6
- block:
- _target_: "text_recognizer.network.transformer.decoder_block.\
- DecoderBlock"
- self_attn:
- _target_: text_recognizer.network.transformer.Attention
- dim: *dim
- num_heads: 8
- dim_head: &dim_head 64
- dropout_rate: &dropout_rate 0.2
- causal: true
- cross_attn:
- _target_: text_recognizer.network.transformer.Attention
- dim: *dim
- num_heads: 8
- dim_head: *dim_head
- dropout_rate: *dropout_rate
- causal: false
- norm:
- _target_: text_recognizer.network.transformer.RMSNorm
- dim: *dim
- ff:
- _target_: text_recognizer.network.transformer.FeedForward
- dim: *dim
- dim_out: null
- expansion_factor: 2
- glu: true
- dropout_rate: *dropout_rate
- rotary_embedding:
- _target_: text_recognizer.network.transformer.RotaryEmbedding
- dim: *dim_head
-
-trainer:
- gradient_clip_val: 1.0
- max_epochs: *epochs
- accumulate_grad_batches: 2
- limit_train_batches: 1.0
- limit_val_batches: 1.0
- limit_test_batches: 1.0