diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-05 00:09:19 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-05 00:09:19 +0200 |
commit | 12abf17cd7c31ae4599be366505a4423fbba4044 (patch) | |
tree | 996e5d549ebbb7d22f5acfdcd321bddec77f98d1 | |
parent | 16e2e420e077253c3b2bc414283281fea557717d (diff) |
Update perceiver conf
-rw-r--r-- | training/conf/experiment/conv_perceiver_lines.yaml | 76 | ||||
-rw-r--r-- | training/conf/network/conv_perceiver.yaml | 23 | ||||
-rw-r--r-- | training/conf/network/conv_transformer.yaml | 2 |
3 files changed, 92 insertions, 9 deletions
diff --git a/training/conf/experiment/conv_perceiver_lines.yaml b/training/conf/experiment/conv_perceiver_lines.yaml new file mode 100644 index 0000000..26fe232 --- /dev/null +++ b/training/conf/experiment/conv_perceiver_lines.yaml @@ -0,0 +1,76 @@ +# @package _global_ + +defaults: + - override /criterion: cross_entropy + - override /callbacks: htr + - override /datamodule: iam_lines + - override /network: conv_perceiver + - override /model: lit_perceiver + - override /lr_scheduler: null + - override /optimizer: null + +tags: [lines, perceiver] +epochs: &epochs 260 +ignore_index: &ignore_index 3 +num_classes: &num_classes 57 +max_output_len: &max_output_len 89 +summary: [[1, 1, 56, 1024]] + +logger: + wandb: + tags: ${tags} + +criterion: + ignore_index: *ignore_index + # label_smoothing: 0.1 + +callbacks: + stochastic_weight_averaging: + _target_: pytorch_lightning.callbacks.StochasticWeightAveraging + swa_epoch_start: 0.75 + swa_lrs: 1.0e-5 + annealing_epochs: 10 + annealing_strategy: cos + device: null + +optimizer: + _target_: adan_pytorch.Adan + lr: 1.0e-4 + betas: [0.02, 0.08, 0.01] + weight_decay: 0.02 + +lr_scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + mode: min + factor: 0.8 + patience: 10 + threshold: 1.0e-4 + threshold_mode: rel + cooldown: 0 + min_lr: 1.0e-5 + eps: 1.0e-8 + verbose: false + interval: epoch + monitor: val/cer + +datamodule: + batch_size: 8 + train_fraction: 0.95 + +network: + input_dims: [1, 1, 56, 1024] + num_classes: *num_classes + pad_index: *ignore_index + encoder: + depth: 5 + decoder: + depth: 6 + +model: + max_output_len: *max_output_len + +trainer: + gradient_clip_val: 1.0 + stochastic_weight_avg: true + max_epochs: *epochs + accumulate_grad_batches: 1 diff --git a/training/conf/network/conv_perceiver.yaml b/training/conf/network/conv_perceiver.yaml index e6906fa..2e12db9 100644 --- a/training/conf/network/conv_perceiver.yaml +++ b/training/conf/network/conv_perceiver.yaml @@ -1,9 +1,10 @@ _target_: text_recognizer.networks.ConvPerceiver input_dims: [1, 1, 576, 640] -hidden_dim: &hidden_dim 144 +hidden_dim: &hidden_dim 128 num_classes: &num_classes 58 -queries_dim: &queries_dim 16 -max_length: 89 +max_length: &max_length 89 +num_queries: *max_length +queries_dim: &queries_dim 64 pad_index: 3 encoder: _target_: text_recognizer.networks.EfficientNet @@ -15,16 +16,22 @@ encoder: out_channels: *hidden_dim decoder: _target_: text_recognizer.networks.perceiver.PerceiverIO - dim: *hidden_dim + dim: 192 cross_heads: 1 cross_head_dim: 64 num_latents: 256 latent_dim: 512 latent_heads: 8 depth: 6 - queries_dim: *queries_dim + queries_dim: 128 logits_dim: *num_classes pixel_embedding: - _target_: text_recognizer.networks.transformer.AxialPositionalEmbedding - dim: *hidden_dim - shape: [3, 64] + _target_: text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbeddingImage + dim: 64 + axial_shape: [3, 64] + axial_dims: [32, 32] +query_pos_emb: + _target_: text_recognizer.networks.transformer.embeddings.absolute.AbsolutePositionalEmbedding + dim: 64 + max_seq_len: *max_length + l2norm_embed: true diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml index c4cf13e..f618ba1 100644 --- a/training/conf/network/conv_transformer.yaml +++ b/training/conf/network/conv_transformer.yaml @@ -1,6 +1,6 @@ _target_: text_recognizer.networks.ConvTransformer input_dims: [1, 1, 576, 640] -hidden_dim: &hidden_dim 96 +hidden_dim: &hidden_dim 144 num_classes: 58 pad_index: 3 encoder: |