summaryrefslogtreecommitdiff
path: root/training/conf/experiment
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-06-27 18:19:47 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-06-27 18:19:47 +0200
commit80ed24eec6b62176587a411dff4f4e46e125e696 (patch)
tree8553c80ca3ba6c5ee3bc2d2596cd499fa1253486 /training/conf/experiment
parent9dba17704492afed6d9c134245f9528eb82c9918 (diff)
Add vq configs
Diffstat (limited to 'training/conf/experiment')
-rw-r--r--training/conf/experiment/vq_transformer_lines.yaml81
1 files changed, 81 insertions, 0 deletions
diff --git a/training/conf/experiment/vq_transformer_lines.yaml b/training/conf/experiment/vq_transformer_lines.yaml
new file mode 100644
index 0000000..dbd8a3b
--- /dev/null
+++ b/training/conf/experiment/vq_transformer_lines.yaml
@@ -0,0 +1,81 @@
+# @package _global_
+
+defaults:
+ - override /criterion: cross_entropy
+ - override /callbacks: htr
+ - override /datamodule: iam_lines
+ - override /network: vq_transformer
+ - override /model: lit_vq_transformer
+ - override /lr_scheduler: null
+ - override /optimizer: null
+
+tags: [lines]
+epochs: &epochs 200
+ignore_index: &ignore_index 3
+num_classes: &num_classes 57
+max_output_len: &max_output_len 89
+summary: [[1, 1, 56, 1024], [1, 89]]
+
+logger:
+ wandb:
+ tags: ${tags}
+ # id: 342qvr1p
+
+criterion:
+ ignore_index: *ignore_index
+ label_smoothing: 0.05
+
+callbacks:
+ stochastic_weight_averaging:
+ _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+ swa_epoch_start: 0.75
+ swa_lrs: 1.0e-5
+ annealing_epochs: 10
+ annealing_strategy: cos
+ device: null
+
+optimizer:
+ _target_: torch.optim.RAdam
+ lr: 3.0e-4
+ betas: [0.9, 0.999]
+ weight_decay: 0
+ eps: 1.0e-8
+
+lr_scheduler:
+ _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+ mode: min
+ factor: 0.8
+ patience: 10
+ threshold: 1.0e-4
+ threshold_mode: rel
+ cooldown: 0
+ min_lr: 1.0e-5
+ eps: 1.0e-8
+ verbose: false
+ interval: epoch
+ monitor: val/cer
+
+datamodule:
+ batch_size: 8
+ train_fraction: 0.95
+
+network:
+ input_dims: [1, 1, 56, 1024]
+ num_classes: *num_classes
+ pad_index: *ignore_index
+ encoder:
+ depth: 5
+ decoder:
+ depth: 6
+ pixel_embedding:
+ shape: [1, 127]
+
+model:
+ max_output_len: *max_output_len
+ vq_loss_weight: 0.1
+
+trainer:
+ gradient_clip_val: 1.0
+ max_epochs: *epochs
+ accumulate_grad_batches: 1
+ # resume_from_checkpoint: /home/aktersnurra/projects/text-recognizer/training/logs/runs/2022-06-27/00-37-40/checkpoints/last.ckpt