From 56dc112cfb649217cd624b4ff305e2db83a383b7 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 11 Sep 2023 22:15:26 +0200 Subject: Update configs --- training/conf/callbacks/wandb/watch.yaml | 2 +- training/conf/decoder/greedy.yaml | 2 +- training/conf/experiment/mammut_lines.yaml | 55 ++++++++++++++++++++++++ training/conf/experiment/vit_lines.yaml | 43 ++++++------------ training/conf/logger/csv.yaml | 4 -- training/conf/lr_scheduler/cosine_annealing.yaml | 2 +- training/conf/model/lit_mammut.yaml | 4 ++ training/conf/network/convnext.yaml | 16 +++++++ training/conf/network/mammut_lines.yaml | 41 ++++++++++++++++++ training/conf/network/vit_lines.yaml | 54 ++++++++++++++--------- training/conf/optimizer/adamw.yaml | 5 +++ training/conf/optimizer/adan.yaml | 4 ++ training/conf/optimizer/lion.yaml | 5 +++ 13 files changed, 180 insertions(+), 57 deletions(-) create mode 100644 training/conf/experiment/mammut_lines.yaml delete mode 100644 training/conf/logger/csv.yaml create mode 100644 training/conf/model/lit_mammut.yaml create mode 100644 training/conf/network/convnext.yaml create mode 100644 training/conf/network/mammut_lines.yaml create mode 100644 training/conf/optimizer/adamw.yaml create mode 100644 training/conf/optimizer/adan.yaml create mode 100644 training/conf/optimizer/lion.yaml diff --git a/training/conf/callbacks/wandb/watch.yaml b/training/conf/callbacks/wandb/watch.yaml index 1f60978..945b624 100644 --- a/training/conf/callbacks/wandb/watch.yaml +++ b/training/conf/callbacks/wandb/watch.yaml @@ -2,4 +2,4 @@ watch_model: _target_: callbacks.wandb.WatchModel log_params: gradients log_freq: 100 - log_graph: true + log_graph: false diff --git a/training/conf/decoder/greedy.yaml b/training/conf/decoder/greedy.yaml index a88b5a6..44ef53e 100644 --- a/training/conf/decoder/greedy.yaml +++ b/training/conf/decoder/greedy.yaml @@ -1,2 +1,2 @@ -_target_: text_recognizer.model.greedy_decoder.GreedyDecoder +_target_: text_recognizer.decoder.greedy_decoder.GreedyDecoder max_output_len: 682 diff --git a/training/conf/experiment/mammut_lines.yaml b/training/conf/experiment/mammut_lines.yaml new file mode 100644 index 0000000..e74e219 --- /dev/null +++ b/training/conf/experiment/mammut_lines.yaml @@ -0,0 +1,55 @@ +# @package _global_ + +defaults: + - override /criterion: cross_entropy + - override /callbacks: htr + - override /datamodule: iam_lines + - override /network: mammut_lines + - override /model: lit_mammut + - override /lr_scheduler: cosine_annealing + - override /optimizer: adan + +tags: [lines, vit] +epochs: &epochs 320 +ignore_index: &ignore_index 3 +# summary: [[1, 1, 56, 1024], [1, 89]] + +logger: + wandb: + tags: ${tags} + +criterion: + ignore_index: *ignore_index + # label_smoothing: 0.05 + + +decoder: + max_output_len: 89 + +# callbacks: +# stochastic_weight_averaging: +# _target_: pytorch_lightning.callbacks.StochasticWeightAveraging +# swa_epoch_start: 0.75 +# swa_lrs: 1.0e-5 +# annealing_epochs: 10 +# annealing_strategy: cos +# device: null + +lr_scheduler: + T_max: *epochs + +datamodule: + batch_size: 8 + train_fraction: 0.95 + +model: + max_output_len: 89 + +trainer: + fast_dev_run: false + gradient_clip_val: 1.0 + max_epochs: *epochs + accumulate_grad_batches: 1 + limit_train_batches: 1.0 + limit_val_batches: 1.0 + limit_test_batches: 1.0 diff --git a/training/conf/experiment/vit_lines.yaml b/training/conf/experiment/vit_lines.yaml index f3049ea..08ed481 100644 --- a/training/conf/experiment/vit_lines.yaml +++ b/training/conf/experiment/vit_lines.yaml @@ -6,11 +6,11 @@ defaults: - override /datamodule: iam_lines - override /network: vit_lines - override /model: lit_transformer - - override /lr_scheduler: null - - override /optimizer: null + - override /lr_scheduler: cosine_annealing + - override /optimizer: adan tags: [lines, vit] -epochs: &epochs 128 +epochs: &epochs 320 ignore_index: &ignore_index 3 # summary: [[1, 1, 56, 1024], [1, 89]] @@ -26,37 +26,20 @@ criterion: decoder: max_output_len: 89 -callbacks: - stochastic_weight_averaging: - _target_: pytorch_lightning.callbacks.StochasticWeightAveraging - swa_epoch_start: 0.75 - swa_lrs: 1.0e-5 - annealing_epochs: 10 - annealing_strategy: cos - device: null - -optimizer: - _target_: adan_pytorch.Adan - lr: 3.0e-4 - betas: [0.02, 0.08, 0.01] - weight_decay: 0.02 +# callbacks: +# stochastic_weight_averaging: +# _target_: pytorch_lightning.callbacks.StochasticWeightAveraging +# swa_epoch_start: 0.75 +# swa_lrs: 1.0e-5 +# annealing_epochs: 10 +# annealing_strategy: cos +# device: null lr_scheduler: - _target_: torch.optim.lr_scheduler.ReduceLROnPlateau - mode: min - factor: 0.8 - patience: 10 - threshold: 1.0e-4 - threshold_mode: rel - cooldown: 0 - min_lr: 1.0e-5 - eps: 1.0e-8 - verbose: false - interval: epoch - monitor: val/cer + T_max: *epochs datamodule: - batch_size: 16 + batch_size: 8 train_fraction: 0.95 model: diff --git a/training/conf/logger/csv.yaml b/training/conf/logger/csv.yaml deleted file mode 100644 index 9fa6cad..0000000 --- a/training/conf/logger/csv.yaml +++ /dev/null @@ -1,4 +0,0 @@ -csv: - _target_: pytorch_lightning.loggers.CSVLogger - name: null - save_dir: "." diff --git a/training/conf/lr_scheduler/cosine_annealing.yaml b/training/conf/lr_scheduler/cosine_annealing.yaml index e8364f0..36e87d4 100644 --- a/training/conf/lr_scheduler/cosine_annealing.yaml +++ b/training/conf/lr_scheduler/cosine_annealing.yaml @@ -2,6 +2,6 @@ _target_: torch.optim.lr_scheduler.CosineAnnealingLR T_max: 256 eta_min: 0.0 last_epoch: -1 - +verbose: false interval: epoch monitor: val/loss diff --git a/training/conf/model/lit_mammut.yaml b/training/conf/model/lit_mammut.yaml new file mode 100644 index 0000000..2495692 --- /dev/null +++ b/training/conf/model/lit_mammut.yaml @@ -0,0 +1,4 @@ +_target_: text_recognizer.model.mammut.LitMaMMUT +max_output_len: 682 +caption_loss_weight: 1.0 +contrastive_loss_weight: 1.0 diff --git a/training/conf/network/convnext.yaml b/training/conf/network/convnext.yaml new file mode 100644 index 0000000..40343a7 --- /dev/null +++ b/training/conf/network/convnext.yaml @@ -0,0 +1,16 @@ +_target_: text_recognizer.network.convnext.convnext.ConvNext +dim: 8 +dim_mults: [2, 8] +depths: [2, 2] +attn: + _target_: text_recognizer.network.convnext.transformer.Transformer + attn: + _target_: text_recognizer.network.convnext.transformer.Attention + dim: 64 + heads: 4 + dim_head: 64 + scale: 8 + ff: + _target_: text_recognizer.network.convnext.transformer.FeedForward + dim: 64 + mult: 4 diff --git a/training/conf/network/mammut_lines.yaml b/training/conf/network/mammut_lines.yaml new file mode 100644 index 0000000..f1c73d0 --- /dev/null +++ b/training/conf/network/mammut_lines.yaml @@ -0,0 +1,41 @@ +_target_: text_recognizer.network.mammut.MaMMUT +encoder: + _target_: text_recognizer.network.vit.Vit + image_height: 56 + image_width: 1024 + patch_height: 56 + patch_width: 8 + dim: &dim 512 + encoder: + _target_: text_recognizer.network.transformer.encoder.Encoder + dim: *dim + heads: 12 + dim_head: 64 + ff_mult: 4 + depth: 6 + dropout_rate: 0.1 + channels: 1 +image_attn_pool: + _target_: text_recognizer.network.transformer.attention.Attention + dim: *dim + heads: 8 + causal: false + dim_head: 64 + ff_mult: 4 + dropout_rate: 0.0 + use_flash: true + norm_context: true + rotary_emb: null +decoder: + _target_: text_recognizer.network.transformer.decoder.Decoder + dim: *dim + ff_mult: 4 + heads: 12 + dim_head: 64 + depth: 6 + dropout_rate: 0.1 +dim: *dim +dim_latent: *dim +num_tokens: 58 +pad_index: 3 +num_image_queries: 256 diff --git a/training/conf/network/vit_lines.yaml b/training/conf/network/vit_lines.yaml index f32cb83..638dae1 100644 --- a/training/conf/network/vit_lines.yaml +++ b/training/conf/network/vit_lines.yaml @@ -1,37 +1,51 @@ -_target_: text_recognizer.network.vit.VisionTransformer -image_height: 56 -image_width: 1024 -patch_height: 28 -patch_width: 32 -dim: &dim 1024 +_target_: text_recognizer.network.convformer.Convformer +image_height: 7 +image_width: 128 +patch_height: 7 +patch_width: 1 +dim: &dim 768 num_classes: &num_classes 58 encoder: _target_: text_recognizer.network.transformer.encoder.Encoder dim: *dim - inner_dim: 2048 - heads: 16 + inner_dim: 3072 + ff_mult: 4 + heads: 12 dim_head: 64 - depth: 6 - dropout_rate: 0.0 + depth: 4 + dropout_rate: 0.1 decoder: _target_: text_recognizer.network.transformer.decoder.Decoder dim: *dim - inner_dim: 2048 - heads: 16 + inner_dim: 3072 + ff_mult: 4 + heads: 12 dim_head: 64 depth: 6 - dropout_rate: 0.0 + dropout_rate: 0.1 token_embedding: _target_: "text_recognizer.network.transformer.embedding.token.\ TokenEmbedding" num_tokens: *num_classes dim: *dim use_l2: true -pos_embedding: - _target_: "text_recognizer.network.transformer.embedding.absolute.\ - AbsolutePositionalEmbedding" - dim: *dim - max_length: 89 - use_l2: true -tie_embeddings: false +tie_embeddings: true pad_index: 3 +channels: 64 +stem: + _target_: text_recognizer.network.convnext.convnext.ConvNext + dim: 8 + dim_mults: [2, 8, 8] + depths: [2, 2, 2] + attn: null + # _target_: text_recognizer.network.convnext.transformer.Transformer + # attn: + # _target_: text_recognizer.network.convnext.transformer.Attention + # dim: 64 + # heads: 4 + # dim_head: 64 + # scale: 8 + # ff: + # _target_: text_recognizer.network.convnext.transformer.FeedForward + # dim: 64 + # mult: 4 diff --git a/training/conf/optimizer/adamw.yaml b/training/conf/optimizer/adamw.yaml new file mode 100644 index 0000000..568c67b --- /dev/null +++ b/training/conf/optimizer/adamw.yaml @@ -0,0 +1,5 @@ +_target_: torch.optim.AdamW +lr: 1.5e-4 +betas: [0.95, 0.98] +weight_decay: 1.0e-2 +eps: 1.0e-6 diff --git a/training/conf/optimizer/adan.yaml b/training/conf/optimizer/adan.yaml new file mode 100644 index 0000000..6950a86 --- /dev/null +++ b/training/conf/optimizer/adan.yaml @@ -0,0 +1,4 @@ +_target_: adan_pytorch.Adan +lr: 2.0e-4 +betas: [0.02, 0.08, 0.01] +weight_decay: 0.02 diff --git a/training/conf/optimizer/lion.yaml b/training/conf/optimizer/lion.yaml new file mode 100644 index 0000000..cbf386a --- /dev/null +++ b/training/conf/optimizer/lion.yaml @@ -0,0 +1,5 @@ +_target_: lion_pytorch.Lion +lr: 5e-5 +betas: [0.95, 0.99] +weight_decay: 0.1 +use_triton: true -- cgit v1.2.3-70-g09d2