From 6968572c1a21394b88a29f675b17b9698784a898 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Fri, 25 Aug 2023 23:19:39 +0200 Subject: Update training stuff --- training/conf/callbacks/wandb/watch.yaml | 2 +- training/conf/config.yaml | 2 +- training/conf/decoder/greedy.yaml | 2 +- .../conf/experiment/conv_transformer_lines.yaml | 30 +++--- .../experiment/conv_transformer_paragraphs.yaml | 30 +++--- training/conf/experiment/vit_lines.yaml | 113 +++++++++++++++++++++ training/conf/logger/csv.yaml | 4 + training/conf/logger/wandb.yaml | 2 +- training/conf/model/lit_transformer.yaml | 2 +- training/conf/network/conv_transformer.yaml | 26 ++--- training/conf/network/convnext.yaml | 8 +- training/conf/network/vit_lines.yaml | 37 +++++++ training/conf/trainer/default.yaml | 6 +- 13 files changed, 208 insertions(+), 56 deletions(-) create mode 100644 training/conf/experiment/vit_lines.yaml create mode 100644 training/conf/logger/csv.yaml create mode 100644 training/conf/network/vit_lines.yaml (limited to 'training/conf') diff --git a/training/conf/callbacks/wandb/watch.yaml b/training/conf/callbacks/wandb/watch.yaml index bada03b..660ae47 100644 --- a/training/conf/callbacks/wandb/watch.yaml +++ b/training/conf/callbacks/wandb/watch.yaml @@ -1,5 +1,5 @@ watch_model: _target_: callbacks.wandb_callbacks.WatchModel - log: gradients + log_params: gradients log_freq: 100 log_graph: true diff --git a/training/conf/config.yaml b/training/conf/config.yaml index e57a8a8..8a1317c 100644 --- a/training/conf/config.yaml +++ b/training/conf/config.yaml @@ -13,7 +13,7 @@ defaults: - network: conv_transformer - optimizer: radam - trainer: default - - experiment: null + - experiment: vit_lines seed: 4711 tune: false diff --git a/training/conf/decoder/greedy.yaml b/training/conf/decoder/greedy.yaml index 1d1a131..a88b5a6 100644 --- a/training/conf/decoder/greedy.yaml +++ b/training/conf/decoder/greedy.yaml @@ -1,2 +1,2 @@ -_target_: text_recognizer.models.greedy_decoder.GreedyDecoder +_target_: text_recognizer.model.greedy_decoder.GreedyDecoder max_output_len: 682 diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml index 948968a..12fe701 100644 --- a/training/conf/experiment/conv_transformer_lines.yaml +++ b/training/conf/experiment/conv_transformer_lines.yaml @@ -56,70 +56,70 @@ datamodule: train_fraction: 0.95 network: - _target_: text_recognizer.networks.ConvTransformer + _target_: text_recognizer.network.ConvTransformer encoder: - _target_: text_recognizer.networks.image_encoder.ImageEncoder + _target_: text_recognizer.network.image_encoder.ImageEncoder encoder: - _target_: text_recognizer.networks.convnext.ConvNext + _target_: text_recognizer.network.convnext.ConvNext dim: 16 dim_mults: [2, 4, 32] depths: [3, 3, 6] downsampling_factors: [[2, 2], [2, 2], [2, 2]] attn: - _target_: text_recognizer.networks.convnext.TransformerBlock + _target_: text_recognizer.network.convnext.TransformerBlock attn: - _target_: text_recognizer.networks.convnext.Attention + _target_: text_recognizer.network.convnext.Attention dim: &dim 512 heads: 4 dim_head: 64 scale: 8 ff: - _target_: text_recognizer.networks.convnext.FeedForward + _target_: text_recognizer.network.convnext.FeedForward dim: *dim mult: 2 pixel_embedding: - _target_: "text_recognizer.networks.transformer.embeddings.axial.\ + _target_: "text_recognizer.network.transformer.embeddings.axial.\ AxialPositionalEmbeddingImage" dim: *dim axial_shape: [7, 128] decoder: - _target_: text_recognizer.networks.text_decoder.TextDecoder + _target_: text_recognizer.network.text_decoder.TextDecoder dim: *dim num_classes: 58 pad_index: *ignore_index decoder: - _target_: text_recognizer.networks.transformer.Decoder + _target_: text_recognizer.network.transformer.Decoder dim: *dim depth: 6 block: - _target_: "text_recognizer.networks.transformer.decoder_block.\ + _target_: "text_recognizer.network.transformer.decoder_block.\ DecoderBlock" self_attn: - _target_: text_recognizer.networks.transformer.Attention + _target_: text_recognizer.network.transformer.Attention dim: *dim num_heads: 8 dim_head: &dim_head 64 dropout_rate: &dropout_rate 0.2 causal: true cross_attn: - _target_: text_recognizer.networks.transformer.Attention + _target_: text_recognizer.network.transformer.Attention dim: *dim num_heads: 8 dim_head: *dim_head dropout_rate: *dropout_rate causal: false norm: - _target_: text_recognizer.networks.transformer.RMSNorm + _target_: text_recognizer.network.transformer.RMSNorm dim: *dim ff: - _target_: text_recognizer.networks.transformer.FeedForward + _target_: text_recognizer.network.transformer.FeedForward dim: *dim dim_out: null expansion_factor: 2 glu: true dropout_rate: *dropout_rate rotary_embedding: - _target_: text_recognizer.networks.transformer.RotaryEmbedding + _target_: text_recognizer.network.transformer.RotaryEmbedding dim: *dim_head model: diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml index ff931cc..9df2ea9 100644 --- a/training/conf/experiment/conv_transformer_paragraphs.yaml +++ b/training/conf/experiment/conv_transformer_paragraphs.yaml @@ -57,70 +57,70 @@ datamodule: train_fraction: 0.95 network: - _target_: text_recognizer.networks.ConvTransformer + _target_: text_recognizer.network.ConvTransformer encoder: - _target_: text_recognizer.networks.image_encoder.ImageEncoder + _target_: text_recognizer.network.image_encoder.ImageEncoder encoder: - _target_: text_recognizer.networks.convnext.ConvNext + _target_: text_recognizer.network.convnext.ConvNext dim: 16 dim_mults: [1, 2, 4, 8, 32] depths: [2, 3, 3, 3, 6] downsampling_factors: [[2, 2], [2, 2], [2, 2], [2, 1], [2, 1]] attn: - _target_: text_recognizer.networks.convnext.TransformerBlock + _target_: text_recognizer.network.convnext.TransformerBlock attn: - _target_: text_recognizer.networks.convnext.Attention + _target_: text_recognizer.network.convnext.Attention dim: &dim 512 heads: 4 dim_head: 64 scale: 8 ff: - _target_: text_recognizer.networks.convnext.FeedForward + _target_: text_recognizer.network.convnext.FeedForward dim: *dim mult: 2 pixel_embedding: - _target_: "text_recognizer.networks.transformer.embeddings.axial.\ + _target_: "text_recognizer.network.transformer.embeddings.axial.\ AxialPositionalEmbeddingImage" dim: *dim axial_shape: [18, 80] decoder: - _target_: text_recognizer.networks.text_decoder.TextDecoder + _target_: text_recognizer.network.text_decoder.TextDecoder dim: *dim num_classes: 58 pad_index: *ignore_index decoder: - _target_: text_recognizer.networks.transformer.Decoder + _target_: text_recognizer.network.transformer.Decoder dim: *dim depth: 6 block: - _target_: "text_recognizer.networks.transformer.decoder_block.\ + _target_: "text_recognizer.network.transformer.decoder_block.\ DecoderBlock" self_attn: - _target_: text_recognizer.networks.transformer.Attention + _target_: text_recognizer.network.transformer.Attention dim: *dim num_heads: 8 dim_head: &dim_head 64 dropout_rate: &dropout_rate 0.2 causal: true cross_attn: - _target_: text_recognizer.networks.transformer.Attention + _target_: text_recognizer.network.transformer.Attention dim: *dim num_heads: 8 dim_head: *dim_head dropout_rate: *dropout_rate causal: false norm: - _target_: text_recognizer.networks.transformer.RMSNorm + _target_: text_recognizer.network.transformer.RMSNorm dim: *dim ff: - _target_: text_recognizer.networks.transformer.FeedForward + _target_: text_recognizer.network.transformer.FeedForward dim: *dim dim_out: null expansion_factor: 2 glu: true dropout_rate: *dropout_rate rotary_embedding: - _target_: text_recognizer.networks.transformer.RotaryEmbedding + _target_: text_recognizer.network.transformer.RotaryEmbedding dim: *dim_head trainer: diff --git a/training/conf/experiment/vit_lines.yaml b/training/conf/experiment/vit_lines.yaml new file mode 100644 index 0000000..e2ddebf --- /dev/null +++ b/training/conf/experiment/vit_lines.yaml @@ -0,0 +1,113 @@ +# @package _global_ + +defaults: + - override /criterion: cross_entropy + - override /callbacks: htr + - override /datamodule: iam_lines + - override /network: null + - override /model: lit_transformer + - override /lr_scheduler: null + - override /optimizer: null + +tags: [lines, vit] +epochs: &epochs 64 +ignore_index: &ignore_index 3 +# summary: [[1, 1, 56, 1024], [1, 89]] + +logger: + wandb: + tags: ${tags} + +criterion: + ignore_index: *ignore_index + # label_smoothing: 0.05 + + +decoder: + max_output_len: 89 + +callbacks: + stochastic_weight_averaging: + _target_: pytorch_lightning.callbacks.StochasticWeightAveraging + swa_epoch_start: 0.75 + swa_lrs: 1.0e-5 + annealing_epochs: 10 + annealing_strategy: cos + device: null + +optimizer: + _target_: adan_pytorch.Adan + lr: 3.0e-4 + betas: [0.02, 0.08, 0.01] + weight_decay: 0.02 + +lr_scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + mode: min + factor: 0.8 + patience: 10 + threshold: 1.0e-4 + threshold_mode: rel + cooldown: 0 + min_lr: 1.0e-5 + eps: 1.0e-8 + verbose: false + interval: epoch + monitor: val/cer + +datamodule: + batch_size: 8 + train_fraction: 0.95 + +network: + _target_: text_recognizer.network.vit.VisionTransformer + image_height: 56 + image_width: 1024 + patch_height: 28 + patch_width: 32 + dim: &dim 1024 + num_classes: &num_classes 58 + encoder: + _target_: text_recognizer.network.transformer.encoder.Encoder + dim: *dim + inner_dim: 2048 + heads: 16 + dim_head: 64 + depth: 4 + dropout_rate: 0.0 + decoder: + _target_: text_recognizer.network.transformer.decoder.Decoder + dim: *dim + inner_dim: 2048 + heads: 16 + dim_head: 64 + depth: 4 + dropout_rate: 0.0 + token_embedding: + _target_: "text_recognizer.network.transformer.embedding.token.\ + TokenEmbedding" + num_tokens: *num_classes + dim: *dim + use_l2: true + pos_embedding: + _target_: "text_recognizer.network.transformer.embedding.absolute.\ + AbsolutePositionalEmbedding" + dim: *dim + max_length: 89 + use_l2: true + tie_embeddings: false + pad_index: 3 + +model: + max_output_len: 89 + +trainer: + fast_dev_run: false + gradient_clip_val: 1.0 + max_epochs: *epochs + accumulate_grad_batches: 1 + limit_val_batches: .02 + limit_test_batches: .02 + limit_train_batches: 1.0 + # limit_val_batches: 1.0 + # limit_test_batches: 1.0 diff --git a/training/conf/logger/csv.yaml b/training/conf/logger/csv.yaml new file mode 100644 index 0000000..9fa6cad --- /dev/null +++ b/training/conf/logger/csv.yaml @@ -0,0 +1,4 @@ +csv: + _target_: pytorch_lightning.loggers.CSVLogger + name: null + save_dir: "." diff --git a/training/conf/logger/wandb.yaml b/training/conf/logger/wandb.yaml index 081ebeb..ba3218a 100644 --- a/training/conf/logger/wandb.yaml +++ b/training/conf/logger/wandb.yaml @@ -1,5 +1,5 @@ wandb: - _target_: pytorch_lightning.loggers.wandb.WandbLogger + _target_: pytorch_lightning.loggers.WandbLogger project: text-recognizer name: null save_dir: "." diff --git a/training/conf/model/lit_transformer.yaml b/training/conf/model/lit_transformer.yaml index e6af035..533f8f3 100644 --- a/training/conf/model/lit_transformer.yaml +++ b/training/conf/model/lit_transformer.yaml @@ -1,2 +1,2 @@ -_target_: text_recognizer.models.LitTransformer +_target_: text_recognizer.model.transformer.LitTransformer max_output_len: 682 diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml index 016adbb..1e03946 100644 --- a/training/conf/network/conv_transformer.yaml +++ b/training/conf/network/conv_transformer.yaml @@ -1,58 +1,58 @@ -_target_: text_recognizer.networks.ConvTransformer +_target_: text_recognizer.network.ConvTransformer encoder: - _target_: text_recognizer.networks.image_encoder.ImageEncoder + _target_: text_recognizer.network.image_encoder.ImageEncoder encoder: - _target_: text_recognizer.networks.convnext.ConvNext + _target_: text_recognizer.network.convnext.ConvNext dim: 16 dim_mults: [2, 4, 8] depths: [3, 3, 6] downsampling_factors: [[2, 2], [2, 2], [2, 2]] pixel_embedding: - _target_: "text_recognizer.networks.transformer.embeddings.axial.\ + _target_: "text_recognizer.network.transformer.embeddings.axial.\ AxialPositionalEmbeddingImage" dim: &hidden_dim 128 axial_shape: [7, 128] axial_dims: [64, 64] decoder: - _target_: text_recognizer.networks.text_decoder.TextDecoder + _target_: text_recognizer.network.text_decoder.TextDecoder hidden_dim: *hidden_dim num_classes: 58 pad_index: 3 decoder: - _target_: text_recognizer.networks.transformer.Decoder + _target_: text_recognizer.network.transformer.Decoder dim: *hidden_dim depth: 10 block: - _target_: text_recognizer.networks.transformer.decoder_block.DecoderBlock + _target_: text_recognizer.network.transformer.decoder_block.DecoderBlock self_attn: - _target_: text_recognizer.networks.transformer.Attention + _target_: text_recognizer.network.transformer.Attention dim: *hidden_dim num_heads: 12 dim_head: 64 dropout_rate: &dropout_rate 0.2 causal: true cross_attn: - _target_: text_recognizer.networks.transformer.Attention + _target_: text_recognizer.network.transformer.Attention dim: *hidden_dim num_heads: 12 dim_head: 64 dropout_rate: *dropout_rate causal: false norm: - _target_: text_recognizer.networks.transformer.RMSNorm + _target_: text_recognizer.network.transformer.RMSNorm dim: *hidden_dim ff: - _target_: text_recognizer.networks.transformer.FeedForward + _target_: text_recognizer.network.transformer.FeedForward dim: *hidden_dim dim_out: null expansion_factor: 2 glu: true dropout_rate: *dropout_rate rotary_embedding: - _target_: text_recognizer.networks.transformer.RotaryEmbedding + _target_: text_recognizer.network.transformer.RotaryEmbedding dim: 64 token_pos_embedding: - _target_: "text_recognizer.networks.transformer.embeddings.fourier.\ + _target_: "text_recognizer.network.transformer.embeddings.fourier.\ PositionalEncoding" dim: *hidden_dim dropout_rate: 0.1 diff --git a/training/conf/network/convnext.yaml b/training/conf/network/convnext.yaml index 63ad424..904bd56 100644 --- a/training/conf/network/convnext.yaml +++ b/training/conf/network/convnext.yaml @@ -1,17 +1,17 @@ -_target_: text_recognizer.networks.convnext.ConvNext +_target_: text_recognizer.network.convnext.ConvNext dim: 16 dim_mults: [2, 4, 8] depths: [3, 3, 6] downsampling_factors: [[2, 2], [2, 2], [2, 2]] attn: - _target_: text_recognizer.networks.convnext.TransformerBlock + _target_: text_recognizer.network.convnext.TransformerBlock attn: - _target_: text_recognizer.networks.convnext.Attention + _target_: text_recognizer.network.convnext.Attention dim: 128 heads: 4 dim_head: 64 scale: 8 ff: - _target_: text_recognizer.networks.convnext.FeedForward + _target_: text_recognizer.network.convnext.FeedForward dim: 128 mult: 4 diff --git a/training/conf/network/vit_lines.yaml b/training/conf/network/vit_lines.yaml new file mode 100644 index 0000000..35f83c3 --- /dev/null +++ b/training/conf/network/vit_lines.yaml @@ -0,0 +1,37 @@ +_target_: text_recognizer.network.vit.VisionTransformer +image_height: 56 +image_width: 1024 +patch_height: 28 +patch_width: 32 +dim: &dim 256 +num_classes: &num_classes 57 +encoder: + _target_: text_recognizer.network.transformer.encoder.Encoder + dim: *dim + inner_dim: 1024 + heads: 8 + dim_head: 64 + depth: 6 + dropout_rate: 0.0 +decoder: + _target_: text_recognizer.network.transformer.decoder.Decoder + dim: *dim + inner_dim: 1024 + heads: 8 + dim_head: 64 + depth: 6 + dropout_rate: 0.0 +token_embedding: + _target_: "text_recognizer.network.transformer.embedding.token.\ + TokenEmbedding" + num_tokens: *num_classes + dim: *dim + use_l2: true +pos_embedding: + _target_: "text_recognizer.network.transformer.embedding.absolute.\ + AbsolutePositionalEmbedding" + dim: *dim + max_length: 89 + use_l2: true +tie_embeddings: true +pad_index: 3 diff --git a/training/conf/trainer/default.yaml b/training/conf/trainer/default.yaml index 6112cd8..2e593e8 100644 --- a/training/conf/trainer/default.yaml +++ b/training/conf/trainer/default.yaml @@ -1,14 +1,12 @@ _target_: pytorch_lightning.Trainer -auto_scale_batch_size: binsearch -auto_lr_find: false gradient_clip_val: 0.5 fast_dev_run: false -gpus: 1 +accelerator: gpu +devices: 1 precision: 16 max_epochs: 256 limit_train_batches: 1.0 limit_val_batches: 1.0 limit_test_batches: 1.0 -resume_from_checkpoint: null accumulate_grad_batches: 1 overfit_batches: 0 -- cgit v1.2.3-70-g09d2