summaryrefslogtreecommitdiff
path: root/training/conf/experiment
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-05 00:09:19 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-05 00:09:19 +0200
commit12abf17cd7c31ae4599be366505a4423fbba4044 (patch)
tree996e5d549ebbb7d22f5acfdcd321bddec77f98d1 /training/conf/experiment
parent16e2e420e077253c3b2bc414283281fea557717d (diff)
Update perceiver conf
Diffstat (limited to 'training/conf/experiment')
-rw-r--r--training/conf/experiment/conv_perceiver_lines.yaml76
1 files changed, 76 insertions, 0 deletions
diff --git a/training/conf/experiment/conv_perceiver_lines.yaml b/training/conf/experiment/conv_perceiver_lines.yaml
new file mode 100644
index 0000000..26fe232
--- /dev/null
+++ b/training/conf/experiment/conv_perceiver_lines.yaml
@@ -0,0 +1,76 @@
+# @package _global_
+
+defaults:
+ - override /criterion: cross_entropy
+ - override /callbacks: htr
+ - override /datamodule: iam_lines
+ - override /network: conv_perceiver
+ - override /model: lit_perceiver
+ - override /lr_scheduler: null
+ - override /optimizer: null
+
+tags: [lines, perceiver]
+epochs: &epochs 260
+ignore_index: &ignore_index 3
+num_classes: &num_classes 57
+max_output_len: &max_output_len 89
+summary: [[1, 1, 56, 1024]]
+
+logger:
+ wandb:
+ tags: ${tags}
+
+criterion:
+ ignore_index: *ignore_index
+ # label_smoothing: 0.1
+
+callbacks:
+ stochastic_weight_averaging:
+ _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+ swa_epoch_start: 0.75
+ swa_lrs: 1.0e-5
+ annealing_epochs: 10
+ annealing_strategy: cos
+ device: null
+
+optimizer:
+ _target_: adan_pytorch.Adan
+ lr: 1.0e-4
+ betas: [0.02, 0.08, 0.01]
+ weight_decay: 0.02
+
+lr_scheduler:
+ _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+ mode: min
+ factor: 0.8
+ patience: 10
+ threshold: 1.0e-4
+ threshold_mode: rel
+ cooldown: 0
+ min_lr: 1.0e-5
+ eps: 1.0e-8
+ verbose: false
+ interval: epoch
+ monitor: val/cer
+
+datamodule:
+ batch_size: 8
+ train_fraction: 0.95
+
+network:
+ input_dims: [1, 1, 56, 1024]
+ num_classes: *num_classes
+ pad_index: *ignore_index
+ encoder:
+ depth: 5
+ decoder:
+ depth: 6
+
+model:
+ max_output_len: *max_output_len
+
+trainer:
+ gradient_clip_val: 1.0
+ stochastic_weight_avg: true
+ max_epochs: *epochs
+ accumulate_grad_batches: 1