Diffstat (limited to 'training/conf/experiment')
-rw-r--r--  training/conf/experiment/conv_transformer_paragraphs.yaml  132
1 file changed, 64 insertions, 68 deletions
diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml
index cdac387..ff931cc 100644
--- a/training/conf/experiment/conv_transformer_paragraphs.yaml
+++ b/training/conf/experiment/conv_transformer_paragraphs.yaml
@@ -10,8 +10,7 @@ defaults:
- override /optimizer: null
tags: [paragraphs]
-epochs: &epochs 600
-num_classes: &num_classes 58
+epochs: &epochs 256
ignore_index: &ignore_index 3
# max_output_len: &max_output_len 682
# summary: [[1, 1, 576, 640], [1, 682]]
@@ -54,83 +53,80 @@ lr_scheduler:
monitor: val/cer
datamodule:
- batch_size: 2
+ batch_size: 4
train_fraction: 0.95
network:
_target_: text_recognizer.networks.ConvTransformer
- input_dims: [1, 1, 576, 640]
- hidden_dim: &hidden_dim 128
- num_classes: *num_classes
- pad_index: 3
encoder:
- _target_: text_recognizer.networks.convnext.ConvNext
- dim: 16
- dim_mults: [1, 2, 4, 8, 8]
- depths: [3, 3, 3, 3, 6]
- downsampling_factors: [[2, 2], [2, 2], [2, 1], [2, 1], [2, 1]]
- attn:
- _target_: text_recognizer.networks.convnext.TransformerBlock
+ _target_: text_recognizer.networks.image_encoder.ImageEncoder
+ encoder:
+ _target_: text_recognizer.networks.convnext.ConvNext
+ dim: 16
+ dim_mults: [1, 2, 4, 8, 32]
+ depths: [2, 3, 3, 3, 6]
+ downsampling_factors: [[2, 2], [2, 2], [2, 2], [2, 1], [2, 1]]
attn:
- _target_: text_recognizer.networks.convnext.Attention
- dim: 128
- heads: 4
- dim_head: 64
- scale: 8
- ff:
- _target_: text_recognizer.networks.convnext.FeedForward
- dim: 128
- mult: 4
+ _target_: text_recognizer.networks.convnext.TransformerBlock
+ attn:
+ _target_: text_recognizer.networks.convnext.Attention
+ dim: &dim 512
+ heads: 4
+ dim_head: 64
+ scale: 8
+ ff:
+ _target_: text_recognizer.networks.convnext.FeedForward
+ dim: *dim
+ mult: 2
+ pixel_embedding:
+ _target_: "text_recognizer.networks.transformer.embeddings.axial.\
+ AxialPositionalEmbeddingImage"
+ dim: *dim
+ axial_shape: [18, 80]
decoder:
- _target_: text_recognizer.networks.transformer.Decoder
- depth: 6
- dim: *hidden_dim
- block:
- _target_: text_recognizer.networks.transformer.decoder_block.DecoderBlock
- self_attn:
- _target_: text_recognizer.networks.transformer.Attention
- dim: *hidden_dim
- num_heads: 12
- dim_head: 64
- dropout_rate: &dropout_rate 0.2
- causal: true
- rotary_embedding:
- _target_: text_recognizer.networks.transformer.RotaryEmbedding
- dim: 64
- cross_attn:
- _target_: text_recognizer.networks.transformer.Attention
- dim: *hidden_dim
- num_heads: 12
- dim_head: 64
- dropout_rate: *dropout_rate
- causal: false
- norm:
- _target_: text_recognizer.networks.transformer.RMSNorm
- dim: *hidden_dim
- ff:
- _target_: text_recognizer.networks.transformer.FeedForward
- dim: *hidden_dim
- dim_out: null
- expansion_factor: 2
- glu: true
- dropout_rate: *dropout_rate
- pixel_embedding:
- _target_: "text_recognizer.networks.transformer.embeddings.axial.\
- AxialPositionalEmbeddingImage"
- dim: *hidden_dim
- axial_shape: [18, 160]
- axial_dims: [64, 64]
- token_pos_embedding:
- _target_: "text_recognizer.networks.transformer.embeddings.fourier.\
- PositionalEncoding"
- dim: *hidden_dim
- dropout_rate: 0.1
- max_len: 89
+ _target_: text_recognizer.networks.text_decoder.TextDecoder
+ dim: *dim
+ num_classes: 58
+ pad_index: *ignore_index
+ decoder:
+ _target_: text_recognizer.networks.transformer.Decoder
+ dim: *dim
+ depth: 6
+ block:
+ _target_: "text_recognizer.networks.transformer.decoder_block.\
+ DecoderBlock"
+ self_attn:
+ _target_: text_recognizer.networks.transformer.Attention
+ dim: *dim
+ num_heads: 8
+ dim_head: &dim_head 64
+ dropout_rate: &dropout_rate 0.2
+ causal: true
+ cross_attn:
+ _target_: text_recognizer.networks.transformer.Attention
+ dim: *dim
+ num_heads: 8
+ dim_head: *dim_head
+ dropout_rate: *dropout_rate
+ causal: false
+ norm:
+ _target_: text_recognizer.networks.transformer.RMSNorm
+ dim: *dim
+ ff:
+ _target_: text_recognizer.networks.transformer.FeedForward
+ dim: *dim
+ dim_out: null
+ expansion_factor: 2
+ glu: true
+ dropout_rate: *dropout_rate
+ rotary_embedding:
+ _target_: text_recognizer.networks.transformer.RotaryEmbedding
+ dim: *dim_head
trainer:
gradient_clip_val: 1.0
max_epochs: *epochs
- accumulate_grad_batches: 8
+ accumulate_grad_batches: 2
limit_train_batches: 1.0
limit_val_batches: 1.0
limit_test_batches: 1.0
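
For context on the new pixel_embedding: the removed input_dims (and the commented summary shape [1, 1, 576, 640]) correspond to a 576x640 input image, and the new downsampling_factors reduce that to an 18x80 feature grid, which is what axial_shape: [18, 80] encodes. A small sanity-check sketch in Python (illustrative only; the variable names below are not from the repo):

# Quick shape check: the new axial_shape [18, 80] matches the ConvNext
# feature-map grid for a 576x640 input, given the per-stage
# downsampling_factors introduced in this patch.
from functools import reduce

input_h, input_w = 576, 640  # from the removed input_dims / commented "summary" line
factors = [(2, 2), (2, 2), (2, 2), (2, 1), (2, 1)]  # new downsampling_factors

out_h = reduce(lambda h, f: h // f[0], factors, input_h)  # 576 // 2**5 = 18
out_w = reduce(lambda w, f: w // f[1], factors, input_w)  # 640 // 2**3 = 80
assert (out_h, out_w) == (18, 80)  # agrees with pixel_embedding.axial_shape

Separately, the effective batch per optimizer step drops from 2 × 8 = 16 to 4 × 2 = 8 (datamodule batch_size × trainer accumulate_grad_batches).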
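
More broadly, the patch replaces the flat ConvTransformer arguments (input_dims, hidden_dim, num_classes, pad_index) with two nested wrappers, ImageEncoder and TextDecoder, widens the shared width anchor from hidden_dim: 128 to dim: 512, reduces the decoder heads from 12 to 8, and moves the rotary embedding out of the self-attention config. Configs built from _target_ keys like these are normally materialized through Hydra's recursive instantiation; a minimal sketch, assuming a standard Hydra layout (the root config name, override syntax, and entrypoint are assumptions, not taken from this repo):

# Hypothetical instantiation of the experiment's network via Hydra.
# Assumptions (not from the repo): a root config named "config" under
# training/conf, and experiment files applied with "+experiment=...".
from hydra import compose, initialize
from hydra.utils import instantiate

with initialize(config_path="training/conf", version_base=None):
    cfg = compose(
        config_name="config",
        overrides=["+experiment=conv_transformer_paragraphs"],
    )

# The &dim / *dim YAML anchors are resolved at load time, so instantiate()
# only sees plain values; it then recursively builds
# ConvTransformer(encoder=ImageEncoder(...), decoder=TextDecoder(...)).
network = instantiate(cfg.network)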