From 815b3fa220309a50c45ec15de0abfc47a0faeb20 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 4 Oct 2022 22:09:07 +0200 Subject: Update paragraph config --- .../experiment/conv_transformer_paragraphs.yaml | 132 ++++++++++----------- 1 file changed, 64 insertions(+), 68 deletions(-) (limited to 'training/conf/experiment') diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml index cdac387..ff931cc 100644 --- a/training/conf/experiment/conv_transformer_paragraphs.yaml +++ b/training/conf/experiment/conv_transformer_paragraphs.yaml @@ -10,8 +10,7 @@ defaults: - override /optimizer: null tags: [paragraphs] -epochs: &epochs 600 -num_classes: &num_classes 58 +epochs: &epochs 256 ignore_index: &ignore_index 3 # max_output_len: &max_output_len 682 # summary: [[1, 1, 576, 640], [1, 682]] @@ -54,83 +53,80 @@ lr_scheduler: monitor: val/cer datamodule: - batch_size: 2 + batch_size: 4 train_fraction: 0.95 network: _target_: text_recognizer.networks.ConvTransformer - input_dims: [1, 1, 576, 640] - hidden_dim: &hidden_dim 128 - num_classes: *num_classes - pad_index: 3 encoder: - _target_: text_recognizer.networks.convnext.ConvNext - dim: 16 - dim_mults: [1, 2, 4, 8, 8] - depths: [3, 3, 3, 3, 6] - downsampling_factors: [[2, 2], [2, 2], [2, 1], [2, 1], [2, 1]] - attn: - _target_: text_recognizer.networks.convnext.TransformerBlock + _target_: text_recognizer.networks.image_encoder.ImageEncoder + encoder: + _target_: text_recognizer.networks.convnext.ConvNext + dim: 16 + dim_mults: [1, 2, 4, 8, 32] + depths: [2, 3, 3, 3, 6] + downsampling_factors: [[2, 2], [2, 2], [2, 2], [2, 1], [2, 1]] attn: - _target_: text_recognizer.networks.convnext.Attention - dim: 128 - heads: 4 - dim_head: 64 - scale: 8 - ff: - _target_: text_recognizer.networks.convnext.FeedForward - dim: 128 - mult: 4 + _target_: text_recognizer.networks.convnext.TransformerBlock + attn: + _target_: text_recognizer.networks.convnext.Attention + dim: &dim 512 + heads: 4 + dim_head: 64 + scale: 8 + ff: + _target_: text_recognizer.networks.convnext.FeedForward + dim: *dim + mult: 2 + pixel_embedding: + _target_: "text_recognizer.networks.transformer.embeddings.axial.\ + AxialPositionalEmbeddingImage" + dim: *dim + axial_shape: [18, 80] decoder: - _target_: text_recognizer.networks.transformer.Decoder - depth: 6 - dim: *hidden_dim - block: - _target_: text_recognizer.networks.transformer.decoder_block.DecoderBlock - self_attn: - _target_: text_recognizer.networks.transformer.Attention - dim: *hidden_dim - num_heads: 12 - dim_head: 64 - dropout_rate: &dropout_rate 0.2 - causal: true - rotary_embedding: - _target_: text_recognizer.networks.transformer.RotaryEmbedding - dim: 64 - cross_attn: - _target_: text_recognizer.networks.transformer.Attention - dim: *hidden_dim - num_heads: 12 - dim_head: 64 - dropout_rate: *dropout_rate - causal: false - norm: - _target_: text_recognizer.networks.transformer.RMSNorm - dim: *hidden_dim - ff: - _target_: text_recognizer.networks.transformer.FeedForward - dim: *hidden_dim - dim_out: null - expansion_factor: 2 - glu: true - dropout_rate: *dropout_rate - pixel_embedding: - _target_: "text_recognizer.networks.transformer.embeddings.axial.\ - AxialPositionalEmbeddingImage" - dim: *hidden_dim - axial_shape: [18, 160] - axial_dims: [64, 64] - token_pos_embedding: - _target_: "text_recognizer.networks.transformer.embeddings.fourier.\ - PositionalEncoding" - dim: *hidden_dim - dropout_rate: 0.1 - max_len: 89 + _target_: text_recognizer.networks.text_decoder.TextDecoder + dim: *dim + num_classes: 58 + pad_index: *ignore_index + decoder: + _target_: text_recognizer.networks.transformer.Decoder + dim: *dim + depth: 6 + block: + _target_: "text_recognizer.networks.transformer.decoder_block.\ + DecoderBlock" + self_attn: + _target_: text_recognizer.networks.transformer.Attention + dim: *dim + num_heads: 8 + dim_head: &dim_head 64 + dropout_rate: &dropout_rate 0.2 + causal: true + cross_attn: + _target_: text_recognizer.networks.transformer.Attention + dim: *dim + num_heads: 8 + dim_head: *dim_head + dropout_rate: *dropout_rate + causal: false + norm: + _target_: text_recognizer.networks.transformer.RMSNorm + dim: *dim + ff: + _target_: text_recognizer.networks.transformer.FeedForward + dim: *dim + dim_out: null + expansion_factor: 2 + glu: true + dropout_rate: *dropout_rate + rotary_embedding: + _target_: text_recognizer.networks.transformer.RotaryEmbedding + dim: *dim_head trainer: gradient_clip_val: 1.0 max_epochs: *epochs - accumulate_grad_batches: 8 + accumulate_grad_batches: 2 limit_train_batches: 1.0 limit_val_batches: 1.0 limit_test_batches: 1.0 -- cgit v1.2.3-70-g09d2