diff options
Diffstat (limited to 'training/conf/network')
-rw-r--r-- | training/conf/network/conv_transformer.yaml | 8 |
-rw-r--r-- | training/conf/network/decoder/transformer_decoder.yaml | 7 |
2 files changed, 6 insertions, 9 deletions
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml index f72e030..7d57a2d 100644 --- a/training/conf/network/conv_transformer.yaml +++ b/training/conf/network/conv_transformer.yaml @@ -6,8 +6,6 @@ _target_: text_recognizer.networks.conv_transformer.ConvTransformer input_dims: [1, 576, 640] hidden_dim: 256 dropout_rate: 0.2 -max_output_len: 682 -num_classes: 1004 -start_token: <s> -end_token: <e> -pad_token: <p> +max_output_len: 451 +num_classes: 1006 +pad_index: 1002 diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml index 60c5762..3122de1 100644 --- a/training/conf/network/decoder/transformer_decoder.yaml +++ b/training/conf/network/decoder/transformer_decoder.yaml @@ -1,21 +1,20 @@ +defaults: + - rotary_emb: null + _target_: text_recognizer.networks.transformer.Decoder dim: 256 depth: 2 num_heads: 8 attn_fn: text_recognizer.networks.transformer.attention.Attention attn_kwargs: - num_heads: 8 dim_head: 64 dropout_rate: 0.2 norm_fn: torch.nn.LayerNorm ff_fn: text_recognizer.networks.transformer.mlp.FeedForward ff_kwargs: - dim: 256 dim_out: null expansion_factor: 4 glu: true dropout_rate: 0.2 -rotary_emb: null -rotary_emb_dim: null cross_attend: true pre_norm: true |