From 7268035fb9e57342612a8cc50a1fe04e8841ca2f Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Fri, 30 Jul 2021 23:15:03 +0200
Subject: attr bug fix, properly loading network

---
 training/conf/network/conv_transformer.yaml            | 8 +++-----
 training/conf/network/decoder/transformer_decoder.yaml | 7 +++----
 2 files changed, 6 insertions(+), 9 deletions(-)

(limited to 'training/conf/network')

diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml
index f72e030..7d57a2d 100644
--- a/training/conf/network/conv_transformer.yaml
+++ b/training/conf/network/conv_transformer.yaml
@@ -6,8 +6,6 @@ _target_: text_recognizer.networks.conv_transformer.ConvTransformer
 input_dims: [1, 576, 640]
 hidden_dim: 256
 dropout_rate: 0.2
-max_output_len: 682
-num_classes: 1004
-start_token:
-end_token:
-pad_token:
+max_output_len: 451
+num_classes: 1006
+pad_index: 1002
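Hydra builds this network by importing the `_target_` class and calling it with the remaining keys as keyword arguments, so every key in the YAML has to match a constructor parameter; the switch from the three token attributes to a single `pad_index` presumably tracks a changed `ConvTransformer.__init__` signature. A minimal sketch of that loading path, assuming the `text_recognizer` package is importable and the constructor accepts exactly these keywords:

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    # Mirrors the post-patch conv_transformer.yaml hunk above. instantiate()
    # imports the `_target_` class and calls it with the remaining keys as
    # kwargs, so each YAML key must name a real constructor parameter.
    cfg = OmegaConf.create(
        {
            "_target_": "text_recognizer.networks.conv_transformer.ConvTransformer",
            "input_dims": [1, 576, 640],
            "hidden_dim": 256,
            "dropout_rate": 0.2,
            "max_output_len": 451,
            "num_classes": 1006,
            "pad_index": 1002,
        }
    )
    network = instantiate(cfg)  # a stale key raises TypeError at this call

A key left in the YAML without a matching parameter fails exactly at that call, which appears to be the "properly loading network" part of this commit.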
diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml
index 60c5762..3122de1 100644
--- a/training/conf/network/decoder/transformer_decoder.yaml
+++ b/training/conf/network/decoder/transformer_decoder.yaml
@@ -1,21 +1,20 @@
+defaults:
+  - rotary_emb: null
+
 _target_: text_recognizer.networks.transformer.Decoder
 dim: 256
 depth: 2
 num_heads: 8
 attn_fn: text_recognizer.networks.transformer.attention.Attention
 attn_kwargs:
-  num_heads: 8
   dim_head: 64
   dropout_rate: 0.2
 norm_fn: torch.nn.LayerNorm
 ff_fn: text_recognizer.networks.transformer.mlp.FeedForward
 ff_kwargs:
-  dim: 256
   dim_out: null
   expansion_factor: 4
   glu: true
   dropout_rate: 0.2
-rotary_emb: null
-rotary_emb_dim: null
 cross_attend: true
 pre_norm: true
--
cgit v1.2.3-70-g09d2
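Moving `rotary_emb` out of the plain keys (`rotary_emb: null`, `rotary_emb_dim: null`) and into a `defaults` list turns it into a Hydra config group that only composes a sub-config when one is selected, instead of two attributes that were always handed to `Decoder.__init__`. The deletions inside `attn_kwargs` and `ff_kwargs` look like the same class of bug: `num_heads` and `dim` duplicated top-level keys and would be passed twice if the decoder forwards both its own attributes and the kwargs dicts. A sketch of the composition behaviour using Hydra's compose API; the config directory path and the `base` option name are hypothetical:

    from hydra import compose, initialize_config_dir

    # Hypothetical path; assumes a rotary_emb/ config group directory exists
    # next to transformer_decoder.yaml.
    with initialize_config_dir(config_dir="/path/to/training/conf/network/decoder"):
        # `- rotary_emb: null` in the defaults list selects nothing by default,
        # so the composed config carries no rotary embedding node.
        cfg = compose(config_name="transformer_decoder")
        assert "rotary_emb" not in cfg

        # Because the group is declared in the defaults list, an option
        # (hypothetical rotary_emb/base.yaml) can be selected at override
        # time without the `+` prefix.
        cfg = compose(config_name="transformer_decoder", overrides=["rotary_emb=base"])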