From 8fe4b36bf22281c84c4afee811b3435f3b50686d Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sat, 11 Jun 2022 23:10:56 +0200 Subject: Update configs --- .../conf/network/decoder/transformer_decoder.yaml | 70 +++++++++------------- 1 file changed, 29 insertions(+), 41 deletions(-) (limited to 'training/conf/network/decoder') diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml index 7dced16..4588ee9 100644 --- a/training/conf/network/decoder/transformer_decoder.yaml +++ b/training/conf/network/decoder/transformer_decoder.yaml @@ -1,42 +1,30 @@ -_target_: text_recognizer.networks.transformer.Decoder +_target_: text_recognizer.networks.transformer.decoder.Decoder depth: 4 -pre_norm: true -local_depth: 2 -has_pos_emb: true -self_attn: - _target_: text_recognizer.networks.transformer.attention.Attention - dim: 64 - num_heads: 4 - dim_head: 64 - dropout_rate: 0.05 - causal: true - rotary_embedding: - _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding - dim: 128 -local_self_attn: - _target_: text_recognizer.networks.transformer.local_attention.LocalAttention - dim: 64 - num_heads: 4 - dim_head: 64 - dropout_rate: 0.05 - window_size: 22 - look_back: 1 - rotary_embedding: - _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding - dim: 128 -cross_attn: - _target_: text_recognizer.networks.transformer.attention.Attention - dim: 64 - num_heads: 4 - dim_head: 64 - dropout_rate: 0.05 - causal: false -norm: - _target_: text_recognizer.networks.transformer.norm.ScaleNorm - normalized_shape: 192 -ff: - _target_: text_recognizer.networks.transformer.mlp.FeedForward - dim_out: null - expansion_factor: 4 - glu: true - dropout_rate: 0.2 +block: + _target_: text_recognizer.networks.transformer.decoder.DecoderBlock + self_attn: + _target_: text_recognizer.networks.transformer.attention.Attention + dim: 64 + num_heads: 4 + dim_head: 64 + dropout_rate: 0.05 + causal: true + rotary_embedding: + _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding + dim: 128 + cross_attn: + _target_: text_recognizer.networks.transformer.attention.Attention + dim: 64 + num_heads: 4 + dim_head: 64 + dropout_rate: 0.05 + causal: false + norm: + _target_: text_recognizer.networks.transformer.norm.RMSNorm + normalized_shape: 192 + ff: + _target_: text_recognizer.networks.transformer.mlp.FeedForward + dim_out: null + expansion_factor: 4 + glu: true + dropout_rate: 0.2 -- cgit v1.2.3-70-g09d2