diff options
Diffstat (limited to 'training/conf/network')
-rw-r--r-- | training/conf/network/conv_transformer.yaml | 30 | ||||
-rw-r--r-- | training/conf/network/convnext.yaml | 5 |
2 files changed, 23 insertions, 12 deletions
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml index f618ba1..c71296b 100644 --- a/training/conf/network/conv_transformer.yaml +++ b/training/conf/network/conv_transformer.yaml @@ -1,19 +1,17 @@ _target_: text_recognizer.networks.ConvTransformer input_dims: [1, 1, 576, 640] -hidden_dim: &hidden_dim 144 +hidden_dim: &hidden_dim 128 num_classes: 58 pad_index: 3 encoder: - _target_: text_recognizer.networks.EfficientNet - arch: b0 - stochastic_dropout_rate: 0.2 - bn_momentum: 0.99 - bn_eps: 1.0e-3 - depth: 5 - out_channels: *hidden_dim + _target_: text_recognizer.networks.convnext.ConvNext + dim: 16 + dim_mults: [2, 4, 8] + depths: [3, 3, 6] + downsampling_factors: [[2, 2], [2, 2], [2, 2]] decoder: _target_: text_recognizer.networks.transformer.Decoder - depth: 6 + depth: 10 block: _target_: text_recognizer.networks.transformer.DecoderBlock self_attn: @@ -29,7 +27,7 @@ decoder: cross_attn: _target_: text_recognizer.networks.transformer.Attention dim: *hidden_dim - num_heads: 8 + num_heads: 12 dim_head: 64 dropout_rate: *dropout_rate causal: false @@ -44,6 +42,14 @@ decoder: glu: true dropout_rate: *dropout_rate pixel_embedding: - _target_: text_recognizer.networks.transformer.AxialPositionalEmbedding + _target_: "text_recognizer.networks.transformer.embeddings.axial.\ + AxialPositionalEmbeddingImage" + dim: *hidden_dim + axial_shape: [7, 128] + axial_dims: [64, 64] +token_pos_embedding: + _target_: "text_recognizer.networks.transformer.embeddings.fourier.\ + PositionalEncoding" dim: *hidden_dim - shape: [18, 79] + dropout_rate: 0.1 + max_len: 89 diff --git a/training/conf/network/convnext.yaml b/training/conf/network/convnext.yaml new file mode 100644 index 0000000..bc1ce93 --- /dev/null +++ b/training/conf/network/convnext.yaml @@ -0,0 +1,5 @@ +_target_: text_recognizer.networks.convnext.ConvNext +dim: 16 +dim_mults: [2, 4, 8] +depths: [3, 3, 6] +downsampling_factors: [[2, 2], [2, 2], [2, 2]] |