diff options
Diffstat (limited to 'training/conf/experiment/vit_lines.yaml')
-rw-r--r-- | training/conf/experiment/vit_lines.yaml | 113 |
1 files changed, 113 insertions, 0 deletions
diff --git a/training/conf/experiment/vit_lines.yaml b/training/conf/experiment/vit_lines.yaml new file mode 100644 index 0000000..e2ddebf --- /dev/null +++ b/training/conf/experiment/vit_lines.yaml @@ -0,0 +1,113 @@ +# @package _global_ + +defaults: + - override /criterion: cross_entropy + - override /callbacks: htr + - override /datamodule: iam_lines + - override /network: null + - override /model: lit_transformer + - override /lr_scheduler: null + - override /optimizer: null + +tags: [lines, vit] +epochs: &epochs 64 +ignore_index: &ignore_index 3 +# summary: [[1, 1, 56, 1024], [1, 89]] + +logger: + wandb: + tags: ${tags} + +criterion: + ignore_index: *ignore_index + # label_smoothing: 0.05 + + +decoder: + max_output_len: 89 + +callbacks: + stochastic_weight_averaging: + _target_: pytorch_lightning.callbacks.StochasticWeightAveraging + swa_epoch_start: 0.75 + swa_lrs: 1.0e-5 + annealing_epochs: 10 + annealing_strategy: cos + device: null + +optimizer: + _target_: adan_pytorch.Adan + lr: 3.0e-4 + betas: [0.02, 0.08, 0.01] + weight_decay: 0.02 + +lr_scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + mode: min + factor: 0.8 + patience: 10 + threshold: 1.0e-4 + threshold_mode: rel + cooldown: 0 + min_lr: 1.0e-5 + eps: 1.0e-8 + verbose: false + interval: epoch + monitor: val/cer + +datamodule: + batch_size: 8 + train_fraction: 0.95 + +network: + _target_: text_recognizer.network.vit.VisionTransformer + image_height: 56 + image_width: 1024 + patch_height: 28 + patch_width: 32 + dim: &dim 1024 + num_classes: &num_classes 58 + encoder: + _target_: text_recognizer.network.transformer.encoder.Encoder + dim: *dim + inner_dim: 2048 + heads: 16 + dim_head: 64 + depth: 4 + dropout_rate: 0.0 + decoder: + _target_: text_recognizer.network.transformer.decoder.Decoder + dim: *dim + inner_dim: 2048 + heads: 16 + dim_head: 64 + depth: 4 + dropout_rate: 0.0 + token_embedding: + _target_: "text_recognizer.network.transformer.embedding.token.\ + TokenEmbedding" + num_tokens: *num_classes + dim: *dim + use_l2: true + pos_embedding: + _target_: "text_recognizer.network.transformer.embedding.absolute.\ + AbsolutePositionalEmbedding" + dim: *dim + max_length: 89 + use_l2: true + tie_embeddings: false + pad_index: 3 + +model: + max_output_len: 89 + +trainer: + fast_dev_run: false + gradient_clip_val: 1.0 + max_epochs: *epochs + accumulate_grad_batches: 1 + limit_val_batches: .02 + limit_test_batches: .02 + limit_train_batches: 1.0 + # limit_val_batches: 1.0 + # limit_test_batches: 1.0 |