summaryrefslogtreecommitdiff
path: root/training/conf/network
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-03 22:18:13 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-03 22:18:13 +0100
commit0879c06112e11c2091c223575573e40f086e38d5 (patch)
tree2384ff23f01f6ecbaf62186f382a1dbe238f9e88 /training/conf/network
parent6763b70e8543303bdda73fd2a0243b7e02117729 (diff)
Update configs
Diffstat (limited to 'training/conf/network')
-rw-r--r--training/conf/network/decoder/transformer_decoder.yaml51
1 files changed, 36 insertions, 15 deletions
diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml
index bc0678b..7dced16 100644
--- a/training/conf/network/decoder/transformer_decoder.yaml
+++ b/training/conf/network/decoder/transformer_decoder.yaml
@@ -1,21 +1,42 @@
-defaults:
- - rotary_emb: null
-
_target_: text_recognizer.networks.transformer.Decoder
-dim: 128
-depth: 2
-num_heads: 4
-attn_fn: text_recognizer.networks.transformer.attention.Attention
-attn_kwargs:
+depth: 4
+pre_norm: true
+local_depth: 2
+has_pos_emb: true
+self_attn:
+ _target_: text_recognizer.networks.transformer.attention.Attention
+ dim: 64
+ num_heads: 4
dim_head: 64
- dropout_rate: 0.2
-norm_fn: torch.nn.LayerNorm
-ff_fn: text_recognizer.networks.transformer.mlp.FeedForward
-ff_kwargs:
+ dropout_rate: 0.05
+ causal: true
+ rotary_embedding:
+ _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding
+ dim: 128
+local_self_attn:
+ _target_: text_recognizer.networks.transformer.local_attention.LocalAttention
+ dim: 64
+ num_heads: 4
+ dim_head: 64
+ dropout_rate: 0.05
+ window_size: 22
+ look_back: 1
+ rotary_embedding:
+ _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding
+ dim: 128
+cross_attn:
+ _target_: text_recognizer.networks.transformer.attention.Attention
+ dim: 64
+ num_heads: 4
+ dim_head: 64
+ dropout_rate: 0.05
+ causal: false
+norm:
+ _target_: text_recognizer.networks.transformer.norm.ScaleNorm
+ normalized_shape: 192
+ff:
+ _target_: text_recognizer.networks.transformer.mlp.FeedForward
dim_out: null
expansion_factor: 4
glu: true
dropout_rate: 0.2
-cross_attend: true
-pre_norm: true
-rotary_emb: null