summaryrefslogtreecommitdiff
path: root/training/conf/network/decoder/transformer_decoder.yaml
blob: bc0678b9da29b3e09810b5a4e516d1f6ade1b099 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
defaults:
  # `_self_` is listed first so this file's own keys are merged BEFORE the
  # entries below. Without it, Hydra >= 1.1 merges this file last, and the
  # `rotary_emb: null` placeholder at the bottom would overwrite any
  # rotary_emb option selected on the command line. Requires Hydra >= 1.1.
  - _self_
  # No rotary embedding config is selected by default.
  - rotary_emb: null

# Class instantiated from this config via Hydra's `_target_` convention.
_target_: text_recognizer.networks.transformer.Decoder
dim: 128
depth: 2
num_heads: 4
# Dotted import paths below are passed as strings; presumably resolved to
# classes inside Decoder — confirm against its constructor.
attn_fn: text_recognizer.networks.transformer.attention.Attention
attn_kwargs:
  dim_head: 64
  dropout_rate: 0.2
norm_fn: torch.nn.LayerNorm
ff_fn: text_recognizer.networks.transformer.mlp.FeedForward
ff_kwargs:
  dim_out: null
  expansion_factor: 4
  glu: true
  dropout_rate: 0.2
cross_attend: true
pre_norm: true
# Placeholder so the key always exists; replaced when a rotary_emb option
# is selected through the defaults list above.
rotary_emb: null