From 6968572c1a21394b88a29f675b17b9698784a898 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Fri, 25 Aug 2023 23:19:39 +0200 Subject: Update training stuff --- .../conf/experiment/conv_transformer_lines.yaml | 30 +++--- .../experiment/conv_transformer_paragraphs.yaml | 30 +++--- training/conf/experiment/vit_lines.yaml | 113 +++++++++++++++++++++ 3 files changed, 143 insertions(+), 30 deletions(-) create mode 100644 training/conf/experiment/vit_lines.yaml (limited to 'training/conf/experiment') diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml index 948968a..12fe701 100644 --- a/training/conf/experiment/conv_transformer_lines.yaml +++ b/training/conf/experiment/conv_transformer_lines.yaml @@ -56,70 +56,70 @@ datamodule: train_fraction: 0.95 network: - _target_: text_recognizer.networks.ConvTransformer + _target_: text_recognizer.network.ConvTransformer encoder: - _target_: text_recognizer.networks.image_encoder.ImageEncoder + _target_: text_recognizer.network.image_encoder.ImageEncoder encoder: - _target_: text_recognizer.networks.convnext.ConvNext + _target_: text_recognizer.network.convnext.ConvNext dim: 16 dim_mults: [2, 4, 32] depths: [3, 3, 6] downsampling_factors: [[2, 2], [2, 2], [2, 2]] attn: - _target_: text_recognizer.networks.convnext.TransformerBlock + _target_: text_recognizer.network.convnext.TransformerBlock attn: - _target_: text_recognizer.networks.convnext.Attention + _target_: text_recognizer.network.convnext.Attention dim: &dim 512 heads: 4 dim_head: 64 scale: 8 ff: - _target_: text_recognizer.networks.convnext.FeedForward + _target_: text_recognizer.network.convnext.FeedForward dim: *dim mult: 2 pixel_embedding: - _target_: "text_recognizer.networks.transformer.embeddings.axial.\ + _target_: "text_recognizer.network.transformer.embeddings.axial.\ AxialPositionalEmbeddingImage" dim: *dim axial_shape: [7, 128] decoder: - _target_: text_recognizer.networks.text_decoder.TextDecoder + _target_: text_recognizer.network.text_decoder.TextDecoder dim: *dim num_classes: 58 pad_index: *ignore_index decoder: - _target_: text_recognizer.networks.transformer.Decoder + _target_: text_recognizer.network.transformer.Decoder dim: *dim depth: 6 block: - _target_: "text_recognizer.networks.transformer.decoder_block.\ + _target_: "text_recognizer.network.transformer.decoder_block.\ DecoderBlock" self_attn: - _target_: text_recognizer.networks.transformer.Attention + _target_: text_recognizer.network.transformer.Attention dim: *dim num_heads: 8 dim_head: &dim_head 64 dropout_rate: &dropout_rate 0.2 causal: true cross_attn: - _target_: text_recognizer.networks.transformer.Attention + _target_: text_recognizer.network.transformer.Attention dim: *dim num_heads: 8 dim_head: *dim_head dropout_rate: *dropout_rate causal: false norm: - _target_: text_recognizer.networks.transformer.RMSNorm + _target_: text_recognizer.network.transformer.RMSNorm dim: *dim ff: - _target_: text_recognizer.networks.transformer.FeedForward + _target_: text_recognizer.network.transformer.FeedForward dim: *dim dim_out: null expansion_factor: 2 glu: true dropout_rate: *dropout_rate rotary_embedding: - _target_: text_recognizer.networks.transformer.RotaryEmbedding + _target_: text_recognizer.network.transformer.RotaryEmbedding dim: *dim_head model: diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml index ff931cc..9df2ea9 100644 --- a/training/conf/experiment/conv_transformer_paragraphs.yaml +++ b/training/conf/experiment/conv_transformer_paragraphs.yaml @@ -57,70 +57,70 @@ datamodule: train_fraction: 0.95 network: - _target_: text_recognizer.networks.ConvTransformer + _target_: text_recognizer.network.ConvTransformer encoder: - _target_: text_recognizer.networks.image_encoder.ImageEncoder + _target_: text_recognizer.network.image_encoder.ImageEncoder encoder: - _target_: text_recognizer.networks.convnext.ConvNext + _target_: text_recognizer.network.convnext.ConvNext dim: 16 dim_mults: [1, 2, 4, 8, 32] depths: [2, 3, 3, 3, 6] downsampling_factors: [[2, 2], [2, 2], [2, 2], [2, 1], [2, 1]] attn: - _target_: text_recognizer.networks.convnext.TransformerBlock + _target_: text_recognizer.network.convnext.TransformerBlock attn: - _target_: text_recognizer.networks.convnext.Attention + _target_: text_recognizer.network.convnext.Attention dim: &dim 512 heads: 4 dim_head: 64 scale: 8 ff: - _target_: text_recognizer.networks.convnext.FeedForward + _target_: text_recognizer.network.convnext.FeedForward dim: *dim mult: 2 pixel_embedding: - _target_: "text_recognizer.networks.transformer.embeddings.axial.\ + _target_: "text_recognizer.network.transformer.embeddings.axial.\ AxialPositionalEmbeddingImage" dim: *dim axial_shape: [18, 80] decoder: - _target_: text_recognizer.networks.text_decoder.TextDecoder + _target_: text_recognizer.network.text_decoder.TextDecoder dim: *dim num_classes: 58 pad_index: *ignore_index decoder: - _target_: text_recognizer.networks.transformer.Decoder + _target_: text_recognizer.network.transformer.Decoder dim: *dim depth: 6 block: - _target_: "text_recognizer.networks.transformer.decoder_block.\ + _target_: "text_recognizer.network.transformer.decoder_block.\ DecoderBlock" self_attn: - _target_: text_recognizer.networks.transformer.Attention + _target_: text_recognizer.network.transformer.Attention dim: *dim num_heads: 8 dim_head: &dim_head 64 dropout_rate: &dropout_rate 0.2 causal: true cross_attn: - _target_: text_recognizer.networks.transformer.Attention + _target_: text_recognizer.network.transformer.Attention dim: *dim num_heads: 8 dim_head: *dim_head dropout_rate: *dropout_rate causal: false norm: - _target_: text_recognizer.networks.transformer.RMSNorm + _target_: text_recognizer.network.transformer.RMSNorm dim: *dim ff: - _target_: text_recognizer.networks.transformer.FeedForward + _target_: text_recognizer.network.transformer.FeedForward dim: *dim dim_out: null expansion_factor: 2 glu: true dropout_rate: *dropout_rate rotary_embedding: - _target_: text_recognizer.networks.transformer.RotaryEmbedding + _target_: text_recognizer.network.transformer.RotaryEmbedding dim: *dim_head trainer: diff --git a/training/conf/experiment/vit_lines.yaml b/training/conf/experiment/vit_lines.yaml new file mode 100644 index 0000000..e2ddebf --- /dev/null +++ b/training/conf/experiment/vit_lines.yaml @@ -0,0 +1,113 @@ +# @package _global_ + +defaults: + - override /criterion: cross_entropy + - override /callbacks: htr + - override /datamodule: iam_lines + - override /network: null + - override /model: lit_transformer + - override /lr_scheduler: null + - override /optimizer: null + +tags: [lines, vit] +epochs: &epochs 64 +ignore_index: &ignore_index 3 +# summary: [[1, 1, 56, 1024], [1, 89]] + +logger: + wandb: + tags: ${tags} + +criterion: + ignore_index: *ignore_index + # label_smoothing: 0.05 + + +decoder: + max_output_len: 89 + +callbacks: + stochastic_weight_averaging: + _target_: pytorch_lightning.callbacks.StochasticWeightAveraging + swa_epoch_start: 0.75 + swa_lrs: 1.0e-5 + annealing_epochs: 10 + annealing_strategy: cos + device: null + +optimizer: + _target_: adan_pytorch.Adan + lr: 3.0e-4 + betas: [0.02, 0.08, 0.01] + weight_decay: 0.02 + +lr_scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + mode: min + factor: 0.8 + patience: 10 + threshold: 1.0e-4 + threshold_mode: rel + cooldown: 0 + min_lr: 1.0e-5 + eps: 1.0e-8 + verbose: false + interval: epoch + monitor: val/cer + +datamodule: + batch_size: 8 + train_fraction: 0.95 + +network: + _target_: text_recognizer.network.vit.VisionTransformer + image_height: 56 + image_width: 1024 + patch_height: 28 + patch_width: 32 + dim: &dim 1024 + num_classes: &num_classes 58 + encoder: + _target_: text_recognizer.network.transformer.encoder.Encoder + dim: *dim + inner_dim: 2048 + heads: 16 + dim_head: 64 + depth: 4 + dropout_rate: 0.0 + decoder: + _target_: text_recognizer.network.transformer.decoder.Decoder + dim: *dim + inner_dim: 2048 + heads: 16 + dim_head: 64 + depth: 4 + dropout_rate: 0.0 + token_embedding: + _target_: "text_recognizer.network.transformer.embedding.token.\ + TokenEmbedding" + num_tokens: *num_classes + dim: *dim + use_l2: true + pos_embedding: + _target_: "text_recognizer.network.transformer.embedding.absolute.\ + AbsolutePositionalEmbedding" + dim: *dim + max_length: 89 + use_l2: true + tie_embeddings: false + pad_index: 3 + +model: + max_output_len: 89 + +trainer: + fast_dev_run: false + gradient_clip_val: 1.0 + max_epochs: *epochs + accumulate_grad_batches: 1 + limit_val_batches: .02 + limit_test_batches: .02 + limit_train_batches: 1.0 + # limit_val_batches: 1.0 + # limit_test_batches: 1.0 -- cgit v1.2.3-70-g09d2