From bace9540792d517c1687d91f455f9503d9854609 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Fri, 30 Sep 2022 00:35:34 +0200 Subject: Update conv transformer model --- .../conf/experiment/conv_transformer_lines.yaml | 136 +++++++++++---------- training/conf/network/conv_transformer.yaml | 109 +++++++++-------- 2 files changed, 126 insertions(+), 119 deletions(-) (limited to 'training') diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml index e0e426c..d32c7d6 100644 --- a/training/conf/experiment/conv_transformer_lines.yaml +++ b/training/conf/experiment/conv_transformer_lines.yaml @@ -12,7 +12,7 @@ defaults: tags: [lines] epochs: &epochs 260 ignore_index: &ignore_index 3 -num_classes: &num_classes 57 +num_classes: &num_classes 58 max_output_len: &max_output_len 89 # summary: [[1, 1, 56, 1024], [1, 89]] @@ -35,7 +35,7 @@ callbacks: optimizer: _target_: adan_pytorch.Adan - lr: 1.0e-3 + lr: 3.0e-4 betas: [0.02, 0.08, 0.01] weight_decay: 0.02 @@ -59,73 +59,77 @@ datamodule: network: _target_: text_recognizer.networks.ConvTransformer - input_dims: [1, 1, 56, 1024] - hidden_dim: &hidden_dim 384 - num_classes: 58 - pad_index: 3 encoder: - _target_: text_recognizer.networks.convnext.ConvNext - dim: 16 - dim_mults: [2, 4, 24] - depths: [3, 3, 6] - downsampling_factors: [[2, 2], [2, 2], [2, 2]] - attn: - _target_: text_recognizer.networks.convnext.TransformerBlock + _target_: text_recognizer.networks.image_encoder.ImageEncoder + encoder: + _target_: text_recognizer.networks.convnext.ConvNext + dim: 16 + dim_mults: [2, 4, 24] + depths: [3, 3, 6] + downsampling_factors: [[2, 2], [2, 2], [2, 2]] attn: - _target_: text_recognizer.networks.convnext.Attention - dim: *hidden_dim - heads: 4 - dim_head: 64 - scale: 8 - ff: - _target_: text_recognizer.networks.convnext.FeedForward - dim: *hidden_dim - mult: 2 + _target_: text_recognizer.networks.convnext.TransformerBlock + attn: + _target_: text_recognizer.networks.convnext.Attention + dim: *hidden_dim + heads: 4 + dim_head: 64 + scale: 8 + ff: + _target_: text_recognizer.networks.convnext.FeedForward + dim: *hidden_dim + mult: 2 + pixel_embedding: + _target_: "text_recognizer.networks.transformer.embeddings.axial.\ + AxialPositionalEmbeddingImage" + dim: &hidden_dim 384 + axial_shape: [7, 128] + axial_dims: [192, 192] decoder: - _target_: text_recognizer.networks.transformer.Decoder - depth: 6 - dim: *hidden_dim - block: - _target_: text_recognizer.networks.transformer.decoder_block.DecoderBlock - self_attn: - _target_: text_recognizer.networks.transformer.Attention - dim: *hidden_dim - num_heads: 8 - dim_head: 64 - dropout_rate: &dropout_rate 0.2 - causal: true - rotary_embedding: - _target_: text_recognizer.networks.transformer.RotaryEmbedding - dim: 64 - cross_attn: - _target_: text_recognizer.networks.transformer.Attention - dim: *hidden_dim - num_heads: 8 - dim_head: 64 - dropout_rate: *dropout_rate - causal: false - norm: - _target_: text_recognizer.networks.transformer.RMSNorm - dim: *hidden_dim - ff: - _target_: text_recognizer.networks.transformer.FeedForward - dim: *hidden_dim - dim_out: null - expansion_factor: 2 - glu: true - dropout_rate: *dropout_rate - pixel_embedding: - _target_: "text_recognizer.networks.transformer.embeddings.axial.\ - AxialPositionalEmbeddingImage" - dim: *hidden_dim - axial_shape: [7, 128] - axial_dims: [192, 192] - token_pos_embedding: - _target_: "text_recognizer.networks.transformer.embeddings.fourier.\ - PositionalEncoding" - dim: *hidden_dim - dropout_rate: 0.1 - max_len: 89 + _target_: text_recognizer.networks.text_decoder.TextDecoder + hidden_dim: *hidden_dim + num_classes: *num_classes + pad_index: *ignore_index + decoder: + _target_: text_recognizer.networks.transformer.Decoder + dim: *hidden_dim + depth: 6 + block: + _target_: text_recognizer.networks.transformer.decoder_block.\ + DecoderBlock + self_attn: + _target_: text_recognizer.networks.transformer.Attention + dim: *hidden_dim + num_heads: 10 + dim_head: 64 + dropout_rate: &dropout_rate 0.2 + causal: true + cross_attn: + _target_: text_recognizer.networks.transformer.Attention + dim: *hidden_dim + num_heads: 10 + dim_head: 64 + dropout_rate: *dropout_rate + causal: false + norm: + _target_: text_recognizer.networks.transformer.RMSNorm + dim: *hidden_dim + ff: + _target_: text_recognizer.networks.transformer.FeedForward + dim: *hidden_dim + dim_out: null + expansion_factor: 2 + glu: true + dropout_rate: *dropout_rate + rotary_embedding: + _target_: text_recognizer.networks.transformer.RotaryEmbedding + dim: 64 + token_pos_embedding: + _target_: "text_recognizer.networks.transformer.embeddings.fourier.\ + PositionalEncoding" + dim: *hidden_dim + dropout_rate: 0.1 + max_len: *max_output_len model: max_output_len: *max_output_len diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml index 0ef862f..016adbb 100644 --- a/training/conf/network/conv_transformer.yaml +++ b/training/conf/network/conv_transformer.yaml @@ -1,56 +1,59 @@ _target_: text_recognizer.networks.ConvTransformer -input_dims: [1, 1, 576, 640] -hidden_dim: &hidden_dim 128 -num_classes: 58 -pad_index: 3 encoder: - _target_: text_recognizer.networks.convnext.ConvNext - dim: 16 - dim_mults: [2, 4, 8] - depths: [3, 3, 6] - downsampling_factors: [[2, 2], [2, 2], [2, 2]] + _target_: text_recognizer.networks.image_encoder.ImageEncoder + encoder: + _target_: text_recognizer.networks.convnext.ConvNext + dim: 16 + dim_mults: [2, 4, 8] + depths: [3, 3, 6] + downsampling_factors: [[2, 2], [2, 2], [2, 2]] + pixel_embedding: + _target_: "text_recognizer.networks.transformer.embeddings.axial.\ + AxialPositionalEmbeddingImage" + dim: &hidden_dim 128 + axial_shape: [7, 128] + axial_dims: [64, 64] decoder: - _target_: text_recognizer.networks.transformer.Decoder - dim: *hidden_dim - depth: 10 - block: - _target_: text_recognizer.networks.transformer.decoder_block.DecoderBlock - self_attn: - _target_: text_recognizer.networks.transformer.Attention - dim: *hidden_dim - num_heads: 12 - dim_head: 64 - dropout_rate: &dropout_rate 0.2 - causal: true - rotary_embedding: - _target_: text_recognizer.networks.transformer.RotaryEmbedding - dim: 64 - cross_attn: - _target_: text_recognizer.networks.transformer.Attention - dim: *hidden_dim - num_heads: 12 - dim_head: 64 - dropout_rate: *dropout_rate - causal: false - norm: - _target_: text_recognizer.networks.transformer.RMSNorm - dim: *hidden_dim - ff: - _target_: text_recognizer.networks.transformer.FeedForward - dim: *hidden_dim - dim_out: null - expansion_factor: 2 - glu: true - dropout_rate: *dropout_rate -pixel_embedding: - _target_: "text_recognizer.networks.transformer.embeddings.axial.\ - AxialPositionalEmbeddingImage" - dim: *hidden_dim - axial_shape: [7, 128] - axial_dims: [64, 64] -token_pos_embedding: - _target_: "text_recognizer.networks.transformer.embeddings.fourier.\ - PositionalEncoding" - dim: *hidden_dim - dropout_rate: 0.1 - max_len: 89 + _target_: text_recognizer.networks.text_decoder.TextDecoder + hidden_dim: *hidden_dim + num_classes: 58 + pad_index: 3 + decoder: + _target_: text_recognizer.networks.transformer.Decoder + dim: *hidden_dim + depth: 10 + block: + _target_: text_recognizer.networks.transformer.decoder_block.DecoderBlock + self_attn: + _target_: text_recognizer.networks.transformer.Attention + dim: *hidden_dim + num_heads: 12 + dim_head: 64 + dropout_rate: &dropout_rate 0.2 + causal: true + cross_attn: + _target_: text_recognizer.networks.transformer.Attention + dim: *hidden_dim + num_heads: 12 + dim_head: 64 + dropout_rate: *dropout_rate + causal: false + norm: + _target_: text_recognizer.networks.transformer.RMSNorm + dim: *hidden_dim + ff: + _target_: text_recognizer.networks.transformer.FeedForward + dim: *hidden_dim + dim_out: null + expansion_factor: 2 + glu: true + dropout_rate: *dropout_rate + rotary_embedding: + _target_: text_recognizer.networks.transformer.RotaryEmbedding + dim: 64 + token_pos_embedding: + _target_: "text_recognizer.networks.transformer.embeddings.fourier.\ + PositionalEncoding" + dim: *hidden_dim + dropout_rate: 0.1 + max_len: 89 -- cgit v1.2.3-70-g09d2