diff options
-rw-r--r-- | training/conf/experiment/conv_transformer_lines.yaml | 101 | ||||
-rw-r--r-- | training/conf/experiment/conv_transformer_paragraphs.yaml | 7 | ||||
-rw-r--r-- | training/conf/network/conv_transformer.yaml | 30 | ||||
-rw-r--r-- | training/conf/network/convnext.yaml | 5 | ||||
-rw-r--r-- | training/conf/trainer/default.yaml | 3 |
5 files changed, 98 insertions, 48 deletions
diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml index eb9bc9e..d4478cc 100644 --- a/training/conf/experiment/conv_transformer_lines.yaml +++ b/training/conf/experiment/conv_transformer_lines.yaml @@ -4,20 +4,26 @@ defaults: - override /criterion: cross_entropy - override /callbacks: htr - override /datamodule: iam_lines - - override /network: conv_transformer + - override /network: null + # - override /network: conv_transformer - override /model: lit_transformer - override /lr_scheduler: null - override /optimizer: null -epochs: &epochs 512 +tags: [lines] +epochs: &epochs 260 ignore_index: &ignore_index 3 num_classes: &num_classes 57 max_output_len: &max_output_len 89 summary: [[1, 1, 56, 1024], [1, 89]] +logger: + wandb: + tags: ${tags} + criterion: ignore_index: *ignore_index - label_smoothing: 0.05 + # label_smoothing: 0.1 callbacks: stochastic_weight_averaging: @@ -29,30 +35,23 @@ callbacks: device: null optimizer: - _target_: torch.optim.RAdam + _target_: adan_pytorch.Adan lr: 3.0e-4 - betas: [0.9, 0.999] - weight_decay: 0 - eps: 1.0e-8 - parameters: network + betas: [0.02, 0.08, 0.01] + weight_decay: 0.02 lr_scheduler: - _target_: torch.optim.lr_scheduler.OneCycleLR - max_lr: 3.0e-4 - total_steps: null - epochs: *epochs - steps_per_epoch: 1354 - pct_start: 0.3 - anneal_strategy: cos - cycle_momentum: true - base_momentum: 0.85 - max_momentum: 0.95 - div_factor: 25.0 - final_div_factor: 10000.0 - three_phase: true - last_epoch: -1 + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + mode: min + factor: 0.8 + patience: 10 + threshold: 1.0e-4 + threshold_mode: rel + cooldown: 0 + min_lr: 1.0e-5 + eps: 1.0e-8 verbose: false - interval: step + interval: epoch monitor: val/cer datamodule: @@ -60,20 +59,66 @@ datamodule: train_fraction: 0.95 network: + _target_: text_recognizer.networks.ConvTransformer input_dims: [1, 1, 56, 1024] - num_classes: *num_classes - pad_index: *ignore_index + hidden_dim: &hidden_dim 128 + num_classes: 58 + pad_index: 3 encoder: - depth: 5 + _target_: text_recognizer.networks.convnext.ConvNext + dim: 16 + dim_mults: [2, 4, 8] + depths: [3, 3, 6] + downsampling_factors: [[2, 2], [2, 2], [2, 2]] decoder: + _target_: text_recognizer.networks.transformer.Decoder depth: 6 + block: + _target_: text_recognizer.networks.transformer.DecoderBlock + self_attn: + _target_: text_recognizer.networks.transformer.Attention + dim: *hidden_dim + num_heads: 12 + dim_head: 64 + dropout_rate: &dropout_rate 0.2 + causal: true + rotary_embedding: + _target_: text_recognizer.networks.transformer.RotaryEmbedding + dim: 64 + cross_attn: + _target_: text_recognizer.networks.transformer.Attention + dim: *hidden_dim + num_heads: 12 + dim_head: 64 + dropout_rate: *dropout_rate + causal: false + norm: + _target_: text_recognizer.networks.transformer.RMSNorm + dim: *hidden_dim + ff: + _target_: text_recognizer.networks.transformer.FeedForward + dim: *hidden_dim + dim_out: null + expansion_factor: 2 + glu: true + dropout_rate: *dropout_rate pixel_embedding: - shape: [3, 64] + _target_: "text_recognizer.networks.transformer.embeddings.axial.\ + AxialPositionalEmbeddingImage" + dim: *hidden_dim + axial_shape: [7, 128] + axial_dims: [64, 64] + token_pos_embedding: + _target_: "text_recognizer.networks.transformer.embeddings.fourier.\ + PositionalEncoding" + dim: *hidden_dim + dropout_rate: 0.1 + max_len: 89 model: max_output_len: *max_output_len trainer: - gradient_clip_val: 0.5 + gradient_clip_val: 1.0 max_epochs: *epochs accumulate_grad_batches: 1 diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml index 41c236d..4bd3b45 100644 --- a/training/conf/experiment/conv_transformer_paragraphs.yaml +++ b/training/conf/experiment/conv_transformer_paragraphs.yaml @@ -19,10 +19,11 @@ summary: [[1, 1, 576, 640], [1, 682]] logger: wandb: tags: ${tags} + id: 8je5lxmx criterion: ignore_index: *ignore_index - label_smoothing: 0.05 + # label_smoothing: 0.05 callbacks: stochastic_weight_averaging: @@ -62,10 +63,6 @@ network: input_dims: [1, 1, 576, 640] num_classes: *num_classes pad_index: *ignore_index - encoder: - depth: 4 - decoder: - depth: 6 pixel_embedding: shape: [18, 79] diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml index f618ba1..c71296b 100644 --- a/training/conf/network/conv_transformer.yaml +++ b/training/conf/network/conv_transformer.yaml @@ -1,19 +1,17 @@ _target_: text_recognizer.networks.ConvTransformer input_dims: [1, 1, 576, 640] -hidden_dim: &hidden_dim 144 +hidden_dim: &hidden_dim 128 num_classes: 58 pad_index: 3 encoder: - _target_: text_recognizer.networks.EfficientNet - arch: b0 - stochastic_dropout_rate: 0.2 - bn_momentum: 0.99 - bn_eps: 1.0e-3 - depth: 5 - out_channels: *hidden_dim + _target_: text_recognizer.networks.convnext.ConvNext + dim: 16 + dim_mults: [2, 4, 8] + depths: [3, 3, 6] + downsampling_factors: [[2, 2], [2, 2], [2, 2]] decoder: _target_: text_recognizer.networks.transformer.Decoder - depth: 6 + depth: 10 block: _target_: text_recognizer.networks.transformer.DecoderBlock self_attn: @@ -29,7 +27,7 @@ decoder: cross_attn: _target_: text_recognizer.networks.transformer.Attention dim: *hidden_dim - num_heads: 8 + num_heads: 12 dim_head: 64 dropout_rate: *dropout_rate causal: false @@ -44,6 +42,14 @@ decoder: glu: true dropout_rate: *dropout_rate pixel_embedding: - _target_: text_recognizer.networks.transformer.AxialPositionalEmbedding + _target_: "text_recognizer.networks.transformer.embeddings.axial.\ + AxialPositionalEmbeddingImage" + dim: *hidden_dim + axial_shape: [7, 128] + axial_dims: [64, 64] +token_pos_embedding: + _target_: "text_recognizer.networks.transformer.embeddings.fourier.\ + PositionalEncoding" dim: *hidden_dim - shape: [18, 79] + dropout_rate: 0.1 + max_len: 89 diff --git a/training/conf/network/convnext.yaml b/training/conf/network/convnext.yaml new file mode 100644 index 0000000..bc1ce93 --- /dev/null +++ b/training/conf/network/convnext.yaml @@ -0,0 +1,5 @@ +_target_: text_recognizer.networks.convnext.ConvNext +dim: 16 +dim_mults: [2, 4, 8] +depths: [3, 3, 6] +downsampling_factors: [[2, 2], [2, 2], [2, 2]] diff --git a/training/conf/trainer/default.yaml b/training/conf/trainer/default.yaml index c2d0d62..6112cd8 100644 --- a/training/conf/trainer/default.yaml +++ b/training/conf/trainer/default.yaml @@ -1,5 +1,4 @@ _target_: pytorch_lightning.Trainer -stochastic_weight_avg: true auto_scale_batch_size: binsearch auto_lr_find: false gradient_clip_val: 0.5 @@ -7,8 +6,6 @@ fast_dev_run: false gpus: 1 precision: 16 max_epochs: 256 -terminate_on_nan: true -weights_summary: null limit_train_batches: 1.0 limit_val_batches: 1.0 limit_test_batches: 1.0 |