diff options
Diffstat (limited to 'training/conf/experiment/conv_transformer_lines.yaml')
-rw-r--r-- | training/conf/experiment/conv_transformer_lines.yaml | 25 |
1 files changed, 12 insertions, 13 deletions
diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml index 3392cd6..3f5da86 100644 --- a/training/conf/experiment/conv_transformer_lines.yaml +++ b/training/conf/experiment/conv_transformer_lines.yaml @@ -10,11 +10,8 @@ defaults: - override /optimizer: null tags: [lines] -epochs: &epochs 260 +epochs: &epochs 64 ignore_index: &ignore_index 3 -num_classes: &num_classes 58 -max_output_len: &max_output_len 89 -dim: &dim 384 # summary: [[1, 1, 56, 1024], [1, 89]] logger: @@ -57,6 +54,9 @@ lr_scheduler: datamodule: batch_size: 16 train_fraction: 0.95 + transform: + _target_: text_recognizer.data.stems.line.IamLinesStem + augment: false network: _target_: text_recognizer.networks.ConvTransformer @@ -65,14 +65,14 @@ network: encoder: _target_: text_recognizer.networks.convnext.ConvNext dim: 16 - dim_mults: [2, 4, 24] + dim_mults: [2, 4, 32] depths: [3, 3, 6] downsampling_factors: [[2, 2], [2, 2], [2, 2]] attn: _target_: text_recognizer.networks.convnext.TransformerBlock attn: _target_: text_recognizer.networks.convnext.Attention - dim: *dim + dim: &dim 512 heads: 4 dim_head: 64 scale: 8 @@ -85,11 +85,10 @@ network: AxialPositionalEmbeddingImage" dim: *dim axial_shape: [7, 128] - axial_dims: [192, 192] decoder: _target_: text_recognizer.networks.text_decoder.TextDecoder - hidden_dim: *dim - num_classes: *num_classes + dim: *dim + num_classes: 58 pad_index: *ignore_index decoder: _target_: text_recognizer.networks.transformer.Decoder @@ -102,14 +101,14 @@ network: _target_: text_recognizer.networks.transformer.Attention dim: *dim num_heads: 8 - dim_head: 64 + dim_head: &dim_head 64 dropout_rate: &dropout_rate 0.2 causal: true cross_attn: _target_: text_recognizer.networks.transformer.Attention dim: *dim num_heads: 8 - dim_head: 64 + dim_head: *dim_head dropout_rate: *dropout_rate causal: false norm: @@ -124,10 +123,10 @@ network: dropout_rate: *dropout_rate rotary_embedding: _target_: text_recognizer.networks.transformer.RotaryEmbedding - dim: 64 + dim: *dim_head model: - max_output_len: *max_output_len + max_output_len: 89 trainer: gradient_clip_val: 1.0 |