diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/conf/experiment/conv_transformer_paragraphs.yaml | 78 | ||||
-rw-r--r-- | training/conf/network/convnext.yaml | 12 |
2 files changed, 80 insertions, 10 deletions
diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml index 4bd3b45..34eedab 100644 --- a/training/conf/experiment/conv_transformer_paragraphs.yaml +++ b/training/conf/experiment/conv_transformer_paragraphs.yaml @@ -4,7 +4,7 @@ defaults: - override /criterion: cross_entropy - override /callbacks: htr - override /datamodule: iam_extended_paragraphs - - override /network: conv_transformer + - override /network: null - override /model: lit_transformer - override /lr_scheduler: null - override /optimizer: null @@ -35,11 +35,10 @@ callbacks: device: null optimizer: - _target_: torch.optim.RAdam + _target_: adan_pytorch.Adan lr: 3.0e-4 - betas: [0.9, 0.999] - weight_decay: 0 - eps: 1.0e-8 + betas: [0.02, 0.08, 0.01] + weight_decay: 0.02 lr_scheduler: _target_: torch.optim.lr_scheduler.ReduceLROnPlateau @@ -60,14 +59,73 @@ datamodule: train_fraction: 0.95 network: + _target_: text_recognizer.networks.ConvTransformer input_dims: [1, 1, 576, 640] + hidden_dim: &hidden_dim 128 num_classes: *num_classes - pad_index: *ignore_index + pad_index: 3 + encoder: + _target_: text_recognizer.networks.convnext.ConvNext + dim: 16 + dim_mults: [2, 4, 8, 8, 8] + depths: [3, 3, 3, 4, 6] + downsampling_factors: [[2, 2], [2, 2], [2, 2], [2, 1], [2, 1]] + attn: + _target_: text_recognizer.networks.convnext.TransformerBlock + attn: + _target_: text_recognizer.networks.convnext.Attention + dim: 128 + heads: 4 + dim_head: 64 + scale: 8 + ff: + _target_: text_recognizer.networks.convnext.FeedForward + dim: 128 + mult: 4 + decoder: + _target_: text_recognizer.networks.transformer.Decoder + depth: 6 + block: + _target_: text_recognizer.networks.transformer.DecoderBlock + self_attn: + _target_: text_recognizer.networks.transformer.Attention + dim: *hidden_dim + num_heads: 12 + dim_head: 64 + dropout_rate: &dropout_rate 0.2 + causal: true + rotary_embedding: + _target_: text_recognizer.networks.transformer.RotaryEmbedding + dim: 64 + cross_attn: + _target_: text_recognizer.networks.transformer.Attention + dim: *hidden_dim + num_heads: 12 + dim_head: 64 + dropout_rate: *dropout_rate + causal: false + norm: + _target_: text_recognizer.networks.transformer.RMSNorm + dim: *hidden_dim + ff: + _target_: text_recognizer.networks.transformer.FeedForward + dim: *hidden_dim + dim_out: null + expansion_factor: 2 + glu: true + dropout_rate: *dropout_rate pixel_embedding: - shape: [18, 79] - -model: - max_output_len: *max_output_len + _target_: "text_recognizer.networks.transformer.embeddings.axial.\ + AxialPositionalEmbeddingImage" + dim: *hidden_dim + axial_shape: [7, 128] + axial_dims: [64, 64] + token_pos_embedding: + _target_: "text_recognizer.networks.transformer.embeddings.fourier.\ + PositionalEncoding" + dim: *hidden_dim + dropout_rate: 0.1 + max_len: 89 trainer: gradient_clip_val: 1.0 diff --git a/training/conf/network/convnext.yaml b/training/conf/network/convnext.yaml index bc1ce93..63ad424 100644 --- a/training/conf/network/convnext.yaml +++ b/training/conf/network/convnext.yaml @@ -3,3 +3,15 @@ dim: 16 dim_mults: [2, 4, 8] depths: [3, 3, 6] downsampling_factors: [[2, 2], [2, 2], [2, 2]] +attn: + _target_: text_recognizer.networks.convnext.TransformerBlock + attn: + _target_: text_recognizer.networks.convnext.Attention + dim: 128 + heads: 4 + dim_head: 64 + scale: 8 + ff: + _target_: text_recognizer.networks.convnext.FeedForward + dim: 128 + mult: 4 |