summaryrefslogtreecommitdiff
path: root/training/conf/experiment/vit_lines.yaml
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2023-08-25 23:19:39 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2023-08-25 23:19:39 +0200
commit6968572c1a21394b88a29f675b17b9698784a898 (patch)
treed89d1c5c2ec331d38dcb5b6a2dbbd72c9e355b8a /training/conf/experiment/vit_lines.yaml
parent49ca6ade1a19f7f9c702171537fe4be0dfcda66d (diff)
Update training stuff
Diffstat (limited to 'training/conf/experiment/vit_lines.yaml')
-rw-r--r--training/conf/experiment/vit_lines.yaml113
1 files changed, 113 insertions, 0 deletions
diff --git a/training/conf/experiment/vit_lines.yaml b/training/conf/experiment/vit_lines.yaml
new file mode 100644
index 0000000..e2ddebf
--- /dev/null
+++ b/training/conf/experiment/vit_lines.yaml
@@ -0,0 +1,113 @@
+# @package _global_
+
+defaults:
+ - override /criterion: cross_entropy
+ - override /callbacks: htr
+ - override /datamodule: iam_lines
+ - override /network: null
+ - override /model: lit_transformer
+ - override /lr_scheduler: null
+ - override /optimizer: null
+
+tags: [lines, vit]
+epochs: &epochs 64
+ignore_index: &ignore_index 3
+# summary: [[1, 1, 56, 1024], [1, 89]]
+
+logger:
+ wandb:
+ tags: ${tags}
+
+criterion:
+ ignore_index: *ignore_index
+ # label_smoothing: 0.05
+
+
+decoder:
+ max_output_len: 89
+
+callbacks:
+ stochastic_weight_averaging:
+ _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+ swa_epoch_start: 0.75
+ swa_lrs: 1.0e-5
+ annealing_epochs: 10
+ annealing_strategy: cos
+ device: null
+
+optimizer:
+ _target_: adan_pytorch.Adan
+ lr: 3.0e-4
+ betas: [0.02, 0.08, 0.01]
+ weight_decay: 0.02
+
+lr_scheduler:
+ _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+ mode: min
+ factor: 0.8
+ patience: 10
+ threshold: 1.0e-4
+ threshold_mode: rel
+ cooldown: 0
+ min_lr: 1.0e-5
+ eps: 1.0e-8
+ verbose: false
+ interval: epoch
+ monitor: val/cer
+
+datamodule:
+ batch_size: 8
+ train_fraction: 0.95
+
+network:
+ _target_: text_recognizer.network.vit.VisionTransformer
+ image_height: 56
+ image_width: 1024
+ patch_height: 28
+ patch_width: 32
+ dim: &dim 1024
+ num_classes: &num_classes 58
+ encoder:
+ _target_: text_recognizer.network.transformer.encoder.Encoder
+ dim: *dim
+ inner_dim: 2048
+ heads: 16
+ dim_head: 64
+ depth: 4
+ dropout_rate: 0.0
+ decoder:
+ _target_: text_recognizer.network.transformer.decoder.Decoder
+ dim: *dim
+ inner_dim: 2048
+ heads: 16
+ dim_head: 64
+ depth: 4
+ dropout_rate: 0.0
+ token_embedding:
+ _target_: "text_recognizer.network.transformer.embedding.token.\
+ TokenEmbedding"
+ num_tokens: *num_classes
+ dim: *dim
+ use_l2: true
+ pos_embedding:
+ _target_: "text_recognizer.network.transformer.embedding.absolute.\
+ AbsolutePositionalEmbedding"
+ dim: *dim
+ max_length: 89
+ use_l2: true
+ tie_embeddings: false
+ pad_index: 3
+
+model:
+ max_output_len: 89
+
+trainer:
+ fast_dev_run: false
+ gradient_clip_val: 1.0
+ max_epochs: *epochs
+ accumulate_grad_batches: 1
+ limit_val_batches: .02
+ limit_test_batches: .02
+ limit_train_batches: 1.0
+ # limit_val_batches: 1.0
+ # limit_test_batches: 1.0