From 0879c06112e11c2091c223575573e40f086e38d5 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 3 Nov 2021 22:18:13 +0100 Subject: Update configs --- .../conf/network/decoder/transformer_decoder.yaml | 51 +++++++++++++++------- 1 file changed, 36 insertions(+), 15 deletions(-) (limited to 'training/conf/network/decoder') diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml index bc0678b..7dced16 100644 --- a/training/conf/network/decoder/transformer_decoder.yaml +++ b/training/conf/network/decoder/transformer_decoder.yaml @@ -1,21 +1,42 @@ -defaults: - - rotary_emb: null - _target_: text_recognizer.networks.transformer.Decoder -dim: 128 -depth: 2 -num_heads: 4 -attn_fn: text_recognizer.networks.transformer.attention.Attention -attn_kwargs: +depth: 4 +pre_norm: true +local_depth: 2 +has_pos_emb: true +self_attn: + _target_: text_recognizer.networks.transformer.attention.Attention + dim: 64 + num_heads: 4 dim_head: 64 - dropout_rate: 0.2 -norm_fn: torch.nn.LayerNorm -ff_fn: text_recognizer.networks.transformer.mlp.FeedForward -ff_kwargs: + dropout_rate: 0.05 + causal: true + rotary_embedding: + _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding + dim: 128 +local_self_attn: + _target_: text_recognizer.networks.transformer.local_attention.LocalAttention + dim: 64 + num_heads: 4 + dim_head: 64 + dropout_rate: 0.05 + window_size: 22 + look_back: 1 + rotary_embedding: + _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding + dim: 128 +cross_attn: + _target_: text_recognizer.networks.transformer.attention.Attention + dim: 64 + num_heads: 4 + dim_head: 64 + dropout_rate: 0.05 + causal: false +norm: + _target_: text_recognizer.networks.transformer.norm.ScaleNorm + normalized_shape: 192 +ff: + _target_: text_recognizer.networks.transformer.mlp.FeedForward dim_out: null expansion_factor: 4 glu: true dropout_rate: 0.2 -cross_attend: true -pre_norm: true -rotary_emb: null -- cgit v1.2.3-70-g09d2