diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-27 18:19:47 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-27 18:19:47 +0200 |
commit | 80ed24eec6b62176587a411dff4f4e46e125e696 (patch) | |
tree | 8553c80ca3ba6c5ee3bc2d2596cd499fa1253486 | |
parent | 9dba17704492afed6d9c134245f9528eb82c9918 (diff) |
Add vq configs
-rw-r--r-- | training/conf/experiment/vq_transformer_lines.yaml | 81 | ||||
-rw-r--r-- | training/conf/model/lit_vq_transformer.yaml | 5 | ||||
-rw-r--r-- | training/conf/network/vq_transformer.yaml | 65 |
3 files changed, 151 insertions, 0 deletions
diff --git a/training/conf/experiment/vq_transformer_lines.yaml b/training/conf/experiment/vq_transformer_lines.yaml new file mode 100644 index 0000000..dbd8a3b --- /dev/null +++ b/training/conf/experiment/vq_transformer_lines.yaml @@ -0,0 +1,81 @@ +# @package _global_ + +defaults: + - override /criterion: cross_entropy + - override /callbacks: htr + - override /datamodule: iam_lines + - override /network: vq_transformer + - override /model: lit_vq_transformer + - override /lr_scheduler: null + - override /optimizer: null + +tags: [lines] +epochs: &epochs 200 +ignore_index: &ignore_index 3 +num_classes: &num_classes 57 +max_output_len: &max_output_len 89 +summary: [[1, 1, 56, 1024], [1, 89]] + +logger: + wandb: + tags: ${tags} + # id: 342qvr1p + +criterion: + ignore_index: *ignore_index + label_smoothing: 0.05 + +callbacks: + stochastic_weight_averaging: + _target_: pytorch_lightning.callbacks.StochasticWeightAveraging + swa_epoch_start: 0.75 + swa_lrs: 1.0e-5 + annealing_epochs: 10 + annealing_strategy: cos + device: null + +optimizer: + _target_: torch.optim.RAdam + lr: 3.0e-4 + betas: [0.9, 0.999] + weight_decay: 0 + eps: 1.0e-8 + +lr_scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + mode: min + factor: 0.8 + patience: 10 + threshold: 1.0e-4 + threshold_mode: rel + cooldown: 0 + min_lr: 1.0e-5 + eps: 1.0e-8 + verbose: false + interval: epoch + monitor: val/cer + +datamodule: + batch_size: 8 + train_fraction: 0.95 + +network: + input_dims: [1, 1, 56, 1024] + num_classes: *num_classes + pad_index: *ignore_index + encoder: + depth: 5 + decoder: + depth: 6 + pixel_embedding: + shape: [1, 127] + +model: + max_output_len: *max_output_len + vq_loss_weight: 0.1 + +trainer: + gradient_clip_val: 1.0 + max_epochs: *epochs + accumulate_grad_batches: 1 + # resume_from_checkpoint: /home/aktersnurra/projects/text-recognizer/training/logs/runs/2022-06-27/00-37-40/checkpoints/last.ckpt diff --git a/training/conf/model/lit_vq_transformer.yaml b/training/conf/model/lit_vq_transformer.yaml new file mode 100644 index 0000000..4173151 --- /dev/null +++ b/training/conf/model/lit_vq_transformer.yaml @@ -0,0 +1,5 @@ +_target_: text_recognizer.models.LitVqTransformer +max_output_len: 682 +start_token: <s> +end_token: <e> +pad_token: <p> diff --git a/training/conf/network/vq_transformer.yaml b/training/conf/network/vq_transformer.yaml new file mode 100644 index 0000000..d62a4b7 --- /dev/null +++ b/training/conf/network/vq_transformer.yaml @@ -0,0 +1,65 @@ +_target_: text_recognizer.networks.VqTransformer +input_dims: [1, 1, 576, 640] +hidden_dim: &hidden_dim 144 +num_classes: 58 +pad_index: 3 +encoder: + _target_: text_recognizer.networks.EfficientNet + arch: b0 + stochastic_dropout_rate: 0.2 + bn_momentum: 0.99 + bn_eps: 1.0e-3 + depth: 5 + out_channels: *hidden_dim +decoder: + _target_: text_recognizer.networks.transformer.Decoder + depth: 6 + block: + _target_: text_recognizer.networks.transformer.DecoderBlock + self_attn: + _target_: text_recognizer.networks.transformer.Attention + dim: *hidden_dim + num_heads: 8 + dim_head: 64 + dropout_rate: &dropout_rate 0.4 + causal: true + rotary_embedding: + _target_: text_recognizer.networks.transformer.RotaryEmbedding + dim: 64 + cross_attn: + _target_: text_recognizer.networks.transformer.Attention + dim: *hidden_dim + num_heads: 8 + dim_head: 64 + dropout_rate: *dropout_rate + causal: false + norm: + _target_: text_recognizer.networks.transformer.RMSNorm + dim: *hidden_dim + ff: + _target_: text_recognizer.networks.transformer.FeedForward + dim: *hidden_dim + dim_out: null + expansion_factor: 2 + glu: true + dropout_rate: *dropout_rate +pixel_embedding: + _target_: text_recognizer.networks.transformer.AxialPositionalEmbedding + dim: *hidden_dim + shape: [18, 79] +quantizer: + _target_: text_recognizer.networks.quantizer.VectorQuantizer + input_dim: *hidden_dim + codebook: + _target_: text_recognizer.networks.quantizer.CosineSimilarityCodebook + dim: 16 + codebook_size: 64 + kmeans_init: true + kmeans_iters: 10 + decay: 0.8 + eps: 1.0e-5 + threshold_dead: 2 + temperature: 0.0 + commitment: 0.25 + ort_reg_weight: 10 + ort_reg_max_codes: 64 |