diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-30 00:35:34 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-30 00:35:34 +0200 |
commit | bace9540792d517c1687d91f455f9503d9854609 (patch) | |
tree | 38aff5e65a9f7410930aa8a6a92bf7ecd7c89bd5 /training/conf/network | |
parent | 9e85e0883f2e921ca9a57cb2fd93ec47a2535d59 (diff) |
Update conv transformer model
Diffstat (limited to 'training/conf/network')
-rw-r--r-- | training/conf/network/conv_transformer.yaml | 109 |
1 files changed, 56 insertions, 53 deletions
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml index 0ef862f..016adbb 100644 --- a/training/conf/network/conv_transformer.yaml +++ b/training/conf/network/conv_transformer.yaml @@ -1,56 +1,59 @@ _target_: text_recognizer.networks.ConvTransformer -input_dims: [1, 1, 576, 640] -hidden_dim: &hidden_dim 128 -num_classes: 58 -pad_index: 3 encoder: - _target_: text_recognizer.networks.convnext.ConvNext - dim: 16 - dim_mults: [2, 4, 8] - depths: [3, 3, 6] - downsampling_factors: [[2, 2], [2, 2], [2, 2]] + _target_: text_recognizer.networks.image_encoder.ImageEncoder + encoder: + _target_: text_recognizer.networks.convnext.ConvNext + dim: 16 + dim_mults: [2, 4, 8] + depths: [3, 3, 6] + downsampling_factors: [[2, 2], [2, 2], [2, 2]] + pixel_embedding: + _target_: "text_recognizer.networks.transformer.embeddings.axial.\ + AxialPositionalEmbeddingImage" + dim: &hidden_dim 128 + axial_shape: [7, 128] + axial_dims: [64, 64] decoder: - _target_: text_recognizer.networks.transformer.Decoder - dim: *hidden_dim - depth: 10 - block: - _target_: text_recognizer.networks.transformer.decoder_block.DecoderBlock - self_attn: - _target_: text_recognizer.networks.transformer.Attention - dim: *hidden_dim - num_heads: 12 - dim_head: 64 - dropout_rate: &dropout_rate 0.2 - causal: true - rotary_embedding: - _target_: text_recognizer.networks.transformer.RotaryEmbedding - dim: 64 - cross_attn: - _target_: text_recognizer.networks.transformer.Attention - dim: *hidden_dim - num_heads: 12 - dim_head: 64 - dropout_rate: *dropout_rate - causal: false - norm: - _target_: text_recognizer.networks.transformer.RMSNorm - dim: *hidden_dim - ff: - _target_: text_recognizer.networks.transformer.FeedForward - dim: *hidden_dim - dim_out: null - expansion_factor: 2 - glu: true - dropout_rate: *dropout_rate -pixel_embedding: - _target_: "text_recognizer.networks.transformer.embeddings.axial.\ - AxialPositionalEmbeddingImage" - dim: *hidden_dim - axial_shape: [7, 128] - axial_dims: [64, 64] -token_pos_embedding: - _target_: "text_recognizer.networks.transformer.embeddings.fourier.\ - PositionalEncoding" - dim: *hidden_dim - dropout_rate: 0.1 - max_len: 89 + _target_: text_recognizer.networks.text_decoder.TextDecoder + hidden_dim: *hidden_dim + num_classes: 58 + pad_index: 3 + decoder: + _target_: text_recognizer.networks.transformer.Decoder + dim: *hidden_dim + depth: 10 + block: + _target_: text_recognizer.networks.transformer.decoder_block.DecoderBlock + self_attn: + _target_: text_recognizer.networks.transformer.Attention + dim: *hidden_dim + num_heads: 12 + dim_head: 64 + dropout_rate: &dropout_rate 0.2 + causal: true + cross_attn: + _target_: text_recognizer.networks.transformer.Attention + dim: *hidden_dim + num_heads: 12 + dim_head: 64 + dropout_rate: *dropout_rate + causal: false + norm: + _target_: text_recognizer.networks.transformer.RMSNorm + dim: *hidden_dim + ff: + _target_: text_recognizer.networks.transformer.FeedForward + dim: *hidden_dim + dim_out: null + expansion_factor: 2 + glu: true + dropout_rate: *dropout_rate + rotary_embedding: + _target_: text_recognizer.networks.transformer.RotaryEmbedding + dim: 64 + token_pos_embedding: + _target_: "text_recognizer.networks.transformer.embeddings.fourier.\ + PositionalEncoding" + dim: *hidden_dim + dropout_rate: 0.1 + max_len: 89 |