summaryrefslogtreecommitdiff
path: root/training/conf/network/decoder/transformer_decoder.yaml
diff options
context:
space:
mode:
Diffstat (limited to 'training/conf/network/decoder/transformer_decoder.yaml')
-rw-r--r--training/conf/network/decoder/transformer_decoder.yaml70
1 files changed, 29 insertions, 41 deletions
diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml
index 7dced16..4588ee9 100644
--- a/training/conf/network/decoder/transformer_decoder.yaml
+++ b/training/conf/network/decoder/transformer_decoder.yaml
@@ -1,42 +1,30 @@
-_target_: text_recognizer.networks.transformer.Decoder
+_target_: text_recognizer.networks.transformer.decoder.Decoder
depth: 4
-pre_norm: true
-local_depth: 2
-has_pos_emb: true
-self_attn:
- _target_: text_recognizer.networks.transformer.attention.Attention
- dim: 64
- num_heads: 4
- dim_head: 64
- dropout_rate: 0.05
- causal: true
- rotary_embedding:
- _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding
- dim: 128
-local_self_attn:
- _target_: text_recognizer.networks.transformer.local_attention.LocalAttention
- dim: 64
- num_heads: 4
- dim_head: 64
- dropout_rate: 0.05
- window_size: 22
- look_back: 1
- rotary_embedding:
- _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding
- dim: 128
-cross_attn:
- _target_: text_recognizer.networks.transformer.attention.Attention
- dim: 64
- num_heads: 4
- dim_head: 64
- dropout_rate: 0.05
- causal: false
-norm:
- _target_: text_recognizer.networks.transformer.norm.ScaleNorm
- normalized_shape: 192
-ff:
- _target_: text_recognizer.networks.transformer.mlp.FeedForward
- dim_out: null
- expansion_factor: 4
- glu: true
- dropout_rate: 0.2
+block:
+ _target_: text_recognizer.networks.transformer.decoder.DecoderBlock
+ self_attn:
+ _target_: text_recognizer.networks.transformer.attention.Attention
+ dim: 64
+ num_heads: 4
+ dim_head: 64
+ dropout_rate: 0.05
+ causal: true
+ rotary_embedding:
+ _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding
+ dim: 128
+ cross_attn:
+ _target_: text_recognizer.networks.transformer.attention.Attention
+ dim: 64
+ num_heads: 4
+ dim_head: 64
+ dropout_rate: 0.05
+ causal: false
+ norm:
+ _target_: text_recognizer.networks.transformer.norm.RMSNorm
+ normalized_shape: 192
+ ff:
+ _target_: text_recognizer.networks.transformer.mlp.FeedForward
+ dim_out: null
+ expansion_factor: 4
+ glu: true
+ dropout_rate: 0.2