diff options
Diffstat (limited to 'training/conf/experiment')
-rw-r--r-- | training/conf/experiment/conv_transformer_lines.yaml | 19 |
1 files changed, 10 insertions, 9 deletions
diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml index d32c7d6..4e921f2 100644 --- a/training/conf/experiment/conv_transformer_lines.yaml +++ b/training/conf/experiment/conv_transformer_lines.yaml @@ -14,6 +14,7 @@ epochs: &epochs 260 ignore_index: &ignore_index 3 num_classes: &num_classes 58 max_output_len: &max_output_len 89 +dim: &dim 384 # summary: [[1, 1, 56, 1024], [1, 89]] logger: @@ -71,13 +72,13 @@ network: _target_: text_recognizer.networks.convnext.TransformerBlock attn: _target_: text_recognizer.networks.convnext.Attention - dim: *hidden_dim + dim: *dim heads: 4 dim_head: 64 scale: 8 ff: _target_: text_recognizer.networks.convnext.FeedForward - dim: *hidden_dim + dim: *dim mult: 2 pixel_embedding: _target_: "text_recognizer.networks.transformer.embeddings.axial.\ @@ -87,36 +88,36 @@ network: axial_dims: [192, 192] decoder: _target_: text_recognizer.networks.text_decoder.TextDecoder - hidden_dim: *hidden_dim + hidden_dim: *dim num_classes: *num_classes pad_index: *ignore_index decoder: _target_: text_recognizer.networks.transformer.Decoder - dim: *hidden_dim + dim: *dim depth: 6 block: _target_: text_recognizer.networks.transformer.decoder_block.\ DecoderBlock self_attn: _target_: text_recognizer.networks.transformer.Attention - dim: *hidden_dim + dim: *dim num_heads: 10 dim_head: 64 dropout_rate: &dropout_rate 0.2 causal: true cross_attn: _target_: text_recognizer.networks.transformer.Attention - dim: *hidden_dim + dim: *dim num_heads: 10 dim_head: 64 dropout_rate: *dropout_rate causal: false norm: _target_: text_recognizer.networks.transformer.RMSNorm - dim: *hidden_dim + dim: *dim ff: _target_: text_recognizer.networks.transformer.FeedForward - dim: *hidden_dim + dim: *dim dim_out: null expansion_factor: 2 glu: true @@ -127,7 +128,7 @@ network: token_pos_embedding: _target_: "text_recognizer.networks.transformer.embeddings.fourier.\ PositionalEncoding" - dim: *hidden_dim + dim: *dim dropout_rate: 0.1 max_len: *max_output_len |