Diffstat (limited to 'training')
-rw-r--r--  training/conf/experiment/convformer_lines.yaml | 58
-rw-r--r--  training/conf/experiment/mammut_lines.yaml     |  5
-rw-r--r--  training/conf/experiment/vit_lines.yaml        |  3
-rw-r--r--  training/conf/network/convformer_lines.yaml    | 31
-rw-r--r--  training/conf/network/convnext.yaml            | 16
-rw-r--r--  training/conf/network/mammut_cvit_lines.yaml   | 51
-rw-r--r--  training/conf/network/mammut_lines.yaml        | 19
-rw-r--r--  training/conf/network/vit_lines.yaml           | 56
8 files changed, 189 insertions, 50 deletions
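
All eight files are Hydra configs; the experiment files below carry `# @package _global_` headers and are meant to be selected via the `+experiment=` override. A minimal sketch of composing one of them from Python, assuming the config root is `training/conf` and that the repo's root config is named `config` (both assumptions; the actual training entrypoint may differ):

# Sketch: compose the new convformer_lines experiment with Hydra's compose API.
from hydra import compose, initialize_config_dir
from omegaconf import OmegaConf

with initialize_config_dir(config_dir="/abs/path/to/training/conf", version_base=None):
    cfg = compose(config_name="config", overrides=["+experiment=convformer_lines"])

print(OmegaConf.to_yaml(cfg.network))  # resolved convformer_lines network settings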
diff --git a/training/conf/experiment/convformer_lines.yaml b/training/conf/experiment/convformer_lines.yaml
new file mode 100644
index 0000000..f573433
--- /dev/null
+++ b/training/conf/experiment/convformer_lines.yaml
@@ -0,0 +1,58 @@
+# @package _global_
+
+defaults:
+  - override /criterion: cross_entropy
+  - override /callbacks: htr
+  - override /datamodule: iam_lines
+  - override /network: convformer_lines
+  - override /model: lit_transformer
+  - override /lr_scheduler: cosine_annealing
+  - override /optimizer: adan
+
+tags: [lines, vit]
+epochs: &epochs 320
+ignore_index: &ignore_index 3
+# summary: [[1, 1, 56, 1024], [1, 89]]
+
+logger:
+  wandb:
+    tags: ${tags}
+
+criterion:
+  ignore_index: *ignore_index
+  # label_smoothing: 0.05
+
+
+decoder:
+  max_output_len: 89
+
+optimizer:
+  lr: 1.0e-3
+
+# callbacks:
+#   stochastic_weight_averaging:
+#     _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+#     swa_epoch_start: 0.75
+#     swa_lrs: 1.0e-5
+#     annealing_epochs: 10
+#     annealing_strategy: cos
+#     device: null
+
+lr_scheduler:
+  T_max: *epochs
+
+datamodule:
+  batch_size: 8
+  train_fraction: 0.95
+
+model:
+  max_output_len: 89
+
+trainer:
+  fast_dev_run: false
+  gradient_clip_val: 1.0
+  max_epochs: *epochs
+  accumulate_grad_batches: 1
+  limit_train_batches: 1.0
+  limit_val_batches: 1.0
+  limit_test_batches: 1.0
diff --git a/training/conf/experiment/mammut_lines.yaml b/training/conf/experiment/mammut_lines.yaml
index e74e219..eb6f765 100644
--- a/training/conf/experiment/mammut_lines.yaml
+++ b/training/conf/experiment/mammut_lines.yaml
@@ -39,12 +39,15 @@ lr_scheduler:
   T_max: *epochs
 
 datamodule:
-  batch_size: 8
+  batch_size: 16
   train_fraction: 0.95
 
 model:
   max_output_len: 89
 
+optimizer:
+  lr: 1.0e-3
+
 trainer:
   fast_dev_run: false
   gradient_clip_val: 1.0
diff --git a/training/conf/experiment/vit_lines.yaml b/training/conf/experiment/vit_lines.yaml
index 08ed481..2f7731e 100644
--- a/training/conf/experiment/vit_lines.yaml
+++ b/training/conf/experiment/vit_lines.yaml
@@ -26,6 +26,9 @@ criterion:
 decoder:
   max_output_len: 89
 
+optimizer:
+  lr: 1.0e-3
+
 # callbacks:
 #   stochastic_weight_averaging:
 #     _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
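
All three experiments pin `criterion.ignore_index` to the pad index 3 and cap decoding at 89 tokens. A small sketch of what that `ignore_index` does, using the standard `torch.nn.CrossEntropyLoss` that the `cross_entropy` criterion presumably wraps (an assumption; shapes follow the commented `# summary` above):

# Sketch: ignore_index=3 masks padded target positions out of the loss.
# num_classes=57 and max_output_len=89 are taken from the configs above.
import torch
from torch import nn

criterion = nn.CrossEntropyLoss(ignore_index=3)
logits = torch.randn(1, 57, 89)           # (batch, num_classes, seq_len)
targets = torch.randint(4, 57, (1, 89))   # token indices
targets[0, 40:] = 3                       # right-pad the label sequence
loss = criterion(logits, targets)         # pad positions contribute nothing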
diff --git a/training/conf/network/convformer_lines.yaml b/training/conf/network/convformer_lines.yaml
new file mode 100644
index 0000000..ef9c831
--- /dev/null
+++ b/training/conf/network/convformer_lines.yaml
@@ -0,0 +1,31 @@
+_target_: text_recognizer.network.convformer.Convformer
+image_height: 7
+image_width: 128
+patch_height: 1
+patch_width: 1
+dim: &dim 512
+num_classes: &num_classes 57
+encoder:
+  _target_: text_recognizer.network.convnext.convnext.ConvNext
+  dim: 16
+  dim_mults: [2, 8, 32]
+  depths: [2, 2, 2]
+  attn: null
+decoder:
+  _target_: text_recognizer.network.transformer.decoder.Decoder
+  dim: *dim
+  ff_mult: 4
+  heads: 12
+  dim_head: 64
+  depth: 6
+  dropout_rate: 0.
+  one_kv_head: true
+token_embedding:
+  _target_: "text_recognizer.network.transformer.embedding.token.\
+    TokenEmbedding"
+  num_tokens: *num_classes
+  dim: *dim
+  use_l2: true
+tie_embeddings: false
+pad_index: 3
+channels: 512
diff --git a/training/conf/network/convnext.yaml b/training/conf/network/convnext.yaml
index 40343a7..bcbc78e 100644
--- a/training/conf/network/convnext.yaml
+++ b/training/conf/network/convnext.yaml
@@ -1,15 +1,15 @@
 _target_: text_recognizer.network.convnext.convnext.ConvNext
 dim: 8
-dim_mults: [2, 8]
-depths: [2, 2]
+dim_mults: [2, 8, 8, 8]
+depths: [2, 2, 2, 2]
 attn:
   _target_: text_recognizer.network.convnext.transformer.Transformer
-  attn:
-    _target_: text_recognizer.network.convnext.transformer.Attention
-    dim: 64
-    heads: 4
-    dim_head: 64
-    scale: 8
+  attn: null
+    # _target_: text_recognizer.network.convnext.transformer.Attention
+    # dim: 64
+    # heads: 4
+    # dim_head: 64
+    # scale: 8
   ff:
     _target_: text_recognizer.network.convnext.transformer.FeedForward
     dim: 64
diff --git a/training/conf/network/mammut_cvit_lines.yaml b/training/conf/network/mammut_cvit_lines.yaml
new file mode 100644
index 0000000..75fcccb
--- /dev/null
+++ b/training/conf/network/mammut_cvit_lines.yaml
@@ -0,0 +1,51 @@
+_target_: text_recognizer.network.mammut.MaMMUT
+encoder:
+  _target_: text_recognizer.network.cvit.CVit
+  image_height: 7
+  image_width: 128
+  patch_height: 7
+  patch_width: 1
+  dim: &dim 512
+  encoder:
+    _target_: text_recognizer.network.transformer.encoder.Encoder
+    dim: *dim
+    heads: 8
+    dim_head: 64
+    ff_mult: 4
+    depth: 2
+    dropout_rate: 0.5
+    use_rotary_emb: true
+    one_kv_head: true
+  stem:
+    _target_: text_recognizer.network.convnext.convnext.ConvNext
+    dim: 16
+    dim_mults: [2, 8, 32]
+    depths: [2, 2, 4]
+    attn: null
+  channels: 512
+image_attn_pool:
+  _target_: text_recognizer.network.transformer.attention.Attention
+  dim: *dim
+  heads: 8
+  causal: false
+  dim_head: 64
+  ff_mult: 4
+  dropout_rate: 0.0
+  use_flash: true
+  norm_context: true
+  use_rotary_emb: false
+  one_kv_head: true
+decoder:
+  _target_: text_recognizer.network.transformer.decoder.Decoder
+  dim: *dim
+  ff_mult: 4
+  heads: 8
+  dim_head: 64
+  depth: 6
+  dropout_rate: 0.5
+  one_kv_head: true
+dim: *dim
+dim_latent: *dim
+num_tokens: 57
+pad_index: 3
+num_image_queries: 64
diff --git a/training/conf/network/mammut_lines.yaml b/training/conf/network/mammut_lines.yaml
index f1c73d0..0b27f09 100644
--- a/training/conf/network/mammut_lines.yaml
+++ b/training/conf/network/mammut_lines.yaml
@@ -4,17 +4,20 @@ encoder:
   image_height: 56
   image_width: 1024
   patch_height: 56
-  patch_width: 8
+  patch_width: 2
   dim: &dim 512
   encoder:
     _target_: text_recognizer.network.transformer.encoder.Encoder
     dim: *dim
-    heads: 12
+    heads: 16
     dim_head: 64
     ff_mult: 4
     depth: 6
-    dropout_rate: 0.1
+    dropout_rate: 0.
+    use_rotary_emb: true
+    one_kv_head: true
   channels: 1
+  patch_dropout: 0.5
 image_attn_pool:
   _target_: text_recognizer.network.transformer.attention.Attention
   dim: *dim
@@ -25,7 +28,8 @@ image_attn_pool:
   dropout_rate: 0.0
   use_flash: true
   norm_context: true
-  rotary_emb: null
+  use_rotary_emb: false
+  one_kv_head: true
 decoder:
   _target_: text_recognizer.network.transformer.decoder.Decoder
   dim: *dim
@@ -33,9 +37,10 @@ decoder:
   heads: 12
   dim_head: 64
   depth: 6
-  dropout_rate: 0.1
+  dropout_rate: 0.
+  one_kv_head: true
 dim: *dim
 dim_latent: *dim
-num_tokens: 58
+num_tokens: 57
 pad_index: 3
-num_image_queries: 256
+num_image_queries: 128
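
The headline change in mammut_lines is `patch_width: 8 -> 2`, which quadruples the encoder's sequence length; the new `patch_dropout: 0.5` then drops half of those tokens during training. The arithmetic:

# Patch-count arithmetic behind the mammut_lines change (patch_width 8 -> 2).
image_height, image_width = 56, 1024
patch_height, patch_width = 56, 2
num_patches = (image_height // patch_height) * (image_width // patch_width)
print(num_patches)  # 512 tokens per line image (128 before the change)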
diff --git a/training/conf/network/vit_lines.yaml b/training/conf/network/vit_lines.yaml
index 638dae1..a8045c2 100644
--- a/training/conf/network/vit_lines.yaml
+++ b/training/conf/network/vit_lines.yaml
@@ -1,51 +1,39 @@
-_target_: text_recognizer.network.convformer.Convformer
-image_height: 7
-image_width: 128
-patch_height: 7
-patch_width: 1
+_target_: text_recognizer.network.transformer.transformer.Transformer
 dim: &dim 768
-num_classes: &num_classes 58
+num_classes: &num_classes 57
 encoder:
-  _target_: text_recognizer.network.transformer.encoder.Encoder
+  _target_: text_recognizer.network.transformer.vit.Vit
+  image_height: 56
+  image_width: 1024
+  patch_height: 56
+  patch_width: 8
   dim: *dim
-  inner_dim: 3072
-  ff_mult: 4
-  heads: 12
-  dim_head: 64
-  depth: 4
-  dropout_rate: 0.1
+  encoder:
+    _target_: text_recognizer.network.transformer.encoder.Encoder
+    dim: *dim
+    heads: 16
+    dim_head: 64
+    ff_mult: 4
+    depth: 6
+    dropout_rate: 0.
+    use_rotary_emb: true
+    one_kv_head: false
+  channels: 1
+  patch_dropout: 0.4
 decoder:
   _target_: text_recognizer.network.transformer.decoder.Decoder
   dim: *dim
-  inner_dim: 3072
   ff_mult: 4
   heads: 12
   dim_head: 64
   depth: 6
-  dropout_rate: 0.1
+  dropout_rate: 0.
+  one_kv_head: false
 token_embedding:
   _target_: "text_recognizer.network.transformer.embedding.token.\
     TokenEmbedding"
   num_tokens: *num_classes
   dim: *dim
   use_l2: true
-tie_embeddings: true
+tie_embeddings: false
 pad_index: 3
-channels: 64
-stem:
-  _target_: text_recognizer.network.convnext.convnext.ConvNext
-  dim: 8
-  dim_mults: [2, 8, 8]
-  depths: [2, 2, 2]
-  attn: null
-    # _target_: text_recognizer.network.convnext.transformer.Transformer
-    # attn:
-    #   _target_: text_recognizer.network.convnext.transformer.Attention
-    #   dim: 64
-    #   heads: 4
-    #   dim_head: 64
-    #   scale: 8
-    # ff:
-    #   _target_: text_recognizer.network.convnext.transformer.FeedForward
-    #   dim: 64
-    #   mult: 4
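
vit_lines drops the Convformer stem entirely: the new `Vit` encoder patchifies the 56x1024 line image directly with 56x8 patches, giving a 128-token sequence. A self-contained sketch of that patchify step, written as a standard ViT-style rearrange plus projection (the repo's actual `Vit` implementation may differ):

# Sketch of the patch embedding implied by the new vit_lines settings:
# a 1-channel 56x1024 image cut into 56x8 patches -> 128 tokens of dim 768.
import torch
from torch import nn
from einops.layers.torch import Rearrange

patchify = nn.Sequential(
    Rearrange("b c (h ph) (w pw) -> b (h w) (ph pw c)", ph=56, pw=8),
    nn.Linear(56 * 8 * 1, 768),  # project each 448-dim patch to dim=768
)
tokens = patchify(torch.randn(1, 1, 56, 1024))
print(tokens.shape)  # torch.Size([1, 128, 768])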