summaryrefslogtreecommitdiff
path: root/training/conf/network/decoder/transformer_decoder.yaml
blob: 4588ee9a3023bc4f7e78bea6a7e1fbe3c29cca98 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Transformer decoder configuration.
#
# Each `_target_` value is a dotted Python path; the sibling keys are
# presumably passed as keyword arguments when that class is instantiated
# (looks like the Hydra instantiate convention - TODO confirm).
_target_: text_recognizer.networks.transformer.decoder.Decoder
# Presumably the number of stacked decoder blocks - confirm against Decoder.
depth: 4
# Definition of one decoder block: self-attention, cross-attention, a
# normalization layer, and a feed-forward layer.
block:
  _target_: text_recognizer.networks.transformer.decoder.DecoderBlock
  # Self-attention, configured as causal.
  self_attn:
    _target_: text_recognizer.networks.transformer.attention.Attention
    dim: 64
    # NOTE(review): num_heads * dim_head = 256 while dim = 64; presumably
    # Attention projects between the two widths - confirm.
    num_heads: 4
    dim_head: 64
    dropout_rate: 0.05
    causal: true
    # Rotary embedding is configured for self-attention only; cross_attn
    # below has no rotary_embedding entry.
    rotary_embedding:
      _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding
      # NOTE(review): 128 differs from dim_head (64) above - confirm this is
      # the dimension RotaryEmbedding expects.
      dim: 128
  # Cross-attention: same settings as self_attn, except non-causal and
  # without a rotary embedding.
  cross_attn:
    _target_: text_recognizer.networks.transformer.attention.Attention
    dim: 64
    num_heads: 4
    dim_head: 64
    dropout_rate: 0.05
    causal: false
  norm:
    _target_: text_recognizer.networks.transformer.norm.RMSNorm
    # NOTE(review): 192 differs from the attention dim (64) above - confirm
    # which width the normalized tensors actually have.
    normalized_shape: 192
  ff:
    _target_: text_recognizer.networks.transformer.mlp.FeedForward
    # NOTE(review): no `dim` key is set here; presumably FeedForward has a
    # default or the value is supplied elsewhere - confirm.
    dim_out: null
    expansion_factor: 4
    glu: true
    dropout_rate: 0.2