diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-09-30 23:11:11 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-09-30 23:11:11 +0200 |
commit | 0c578ab3e79c00389f6aab79329a7652f2ce6f22 (patch) | |
tree | 9b69c37a874e2e89fd4b2389f8c7a02bfc41cbb6 /training/conf/experiment/vqgan_htr_char_iam_lines.yaml | |
parent | 7b8705f382b1642cf171cf7fcd01295104b9deef (diff) |
Add new config files
Diffstat (limited to 'training/conf/experiment/vqgan_htr_char_iam_lines.yaml')
-rw-r--r-- | training/conf/experiment/vqgan_htr_char_iam_lines.yaml | 90 |
1 files changed, 90 insertions, 0 deletions
diff --git a/training/conf/experiment/vqgan_htr_char_iam_lines.yaml b/training/conf/experiment/vqgan_htr_char_iam_lines.yaml new file mode 100644 index 0000000..9f4791f --- /dev/null +++ b/training/conf/experiment/vqgan_htr_char_iam_lines.yaml @@ -0,0 +1,90 @@ +# @package _global_ + +defaults: + - override /mapping: null + - override /criterion: null + - override /datamodule: null + - override /network: null + - override /model: null + - override /lr_schedulers: null + # - override /optimizers: null + + +criterion: + _target_: text_recognizer.criterions.label_smoothing.LabelSmoothingLoss + smoothing: 0.1 + ignore_index: 3 + +mapping: + _target_: text_recognizer.data.emnist_mapping.EmnistMapping + # extra_symbols: [ "\n" ] + +lr_schedulers: + network: + _target_: torch.optim.lr_scheduler.CosineAnnealingLR + T_max: 512 + eta_min: 4.5e-6 + last_epoch: -1 + interval: epoch + monitor: val/loss + +datamodule: + _target_: text_recognizer.data.iam_lines.IAMLines + batch_size: 4 + num_workers: 12 + train_fraction: 0.8 + augment: false + pin_memory: false + + +# optimizers: +# - _target_: madgrad.MADGRAD +# lr: 2.0e-4 +# momentum: 0.9 +# weight_decay: 0 +# eps: 1.0e-7 +# parameters: network + +network: + _target_: text_recognizer.networks.vq_transformer.VqTransformer + input_dims: [1, 56, 1024] + encoder_dim: 32 + hidden_dim: 32 + dropout_rate: 0.1 + num_classes: 58 + pad_index: 3 + no_grad: true + decoder: + _target_: text_recognizer.networks.transformer.Decoder + dim: 32 + depth: 4 + num_heads: 8 + attn_fn: text_recognizer.networks.transformer.attention.Attention + attn_kwargs: + dim_head: 32 + dropout_rate: 0.2 + norm_fn: torch.nn.LayerNorm + ff_fn: text_recognizer.networks.transformer.mlp.FeedForward + ff_kwargs: + dim_out: null + expansion_factor: 4 + glu: true + dropout_rate: 0.2 + cross_attend: true + pre_norm: true + rotary_emb: null + pretrained_encoder_path: "training/logs/runs/2021-09-26/23-27-57" + +model: + _target_: text_recognizer.models.vq_transformer.VqTransformerLitModel + start_token: <s> + end_token: <e> + pad_token: <p> + max_output_len: 89 # 451 + alpha: 0.0 + +trainer: + max_epochs: 512 + # limit_train_batches: 0.1 + # limit_val_batches: 0.1 + # gradient_clip_val: 0.5 |