diff options
Diffstat (limited to 'training/conf/network')
-rw-r--r-- | training/conf/network/conv_transformer.yaml | 60 |
1 files changed, 49 insertions, 11 deletions
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml index 1d61129..54eb028 100644 --- a/training/conf/network/conv_transformer.yaml +++ b/training/conf/network/conv_transformer.yaml @@ -1,11 +1,49 @@ -defaults: - - encoder: efficientnet - - decoder: transformer_decoder - -_target_: text_recognizer.networks.conv_transformer.ConvTransformer -input_dims: [1, 576, 640] -hidden_dim: 128 -encoder_dim: 1280 -dropout_rate: 0.2 -num_classes: 1006 -pad_index: 1002 +_target_: text_recognizer.networks.ConvTransformer +input_dims: [1, 1, 576, 640] +hidden_dim: &hidden_dim 144 +num_classes: 58 +pad_index: 3 +encoder: + _target_: text_recognizer.networks.EfficientNet + arch: b0 + stochastic_dropout_rate: 0.2 + bn_momentum: 0.99 + bn_eps: 1.0e-3 + depth: 3 + out_channels: 128 +decoder: + _target_: text_recognizer.networks.transformer.Decoder + depth: 6 + block: + _target_: text_recognizer.networks.transformer.DecoderBlock + self_attn: + _target_: text_recognizer.networks.transformer.Attention + dim: *hidden_dim + num_heads: 8 + dim_head: 64 + dropout_rate: &dropout_rate 0.4 + causal: true + rotary_embedding: + _target_: text_recognizer.networks.transformer.RotaryEmbedding + dim: 64 + cross_attn: + _target_: text_recognizer.networks.transformer.Attention + dim: *hidden_dim + num_heads: 8 + dim_head: 64 + dropout_rate: *dropout_rate + causal: false + norm: + _target_: text_recognizer.networks.transformer.RMSNorm + dim: *hidden_dim + ff: + _target_: text_recognizer.networks.transformer.FeedForward + dim: *hidden_dim + dim_out: null + expansion_factor: 2 + glu: true + dropout_rate: *dropout_rate +pixel_embedding: + _target_: text_recognizer.networks.transformer.AxialPositionalEmbedding + dim: *hidden_dim + shape: [72, 80] |