3 files changed, 193 insertions, 22 deletions
diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml
new file mode 100644
index 0000000..d2a666f
--- /dev/null
+++ b/training/conf/experiment/conv_transformer_lines.yaml
@@ -0,0 +1,151 @@
+# @package _global_
+
+defaults:
+  - override /mapping: null
+  - override /criterion: null
+  - override /callbacks: htr
+  - override /datamodule: iam_lines
+  - override /network: null
+  - override /model: null
+  - override /lr_schedulers: null
+  - override /optimizers: null
+
+epochs: &epochs 512
+ignore_index: &ignore_index 3
+num_classes: &num_classes 58
+max_output_len: &max_output_len 89
+summary: [[1, 1, 56, 1024], [1, 89]]
+
+criterion:
+  _target_: text_recognizer.criterion.label_smoothing.LabelSmoothingLoss
+  smoothing: 0.1
+  ignore_index: *ignore_index
+
+mapping: &mapping
+  mapping:
+    _target_: text_recognizer.data.mappings.emnist.EmnistMapping
+    # extra_symbols: [ "\n" ]
+
+callbacks:
+  stochastic_weight_averaging:
+    _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+    swa_epoch_start: 0.75
+    swa_lrs: 1.0e-5
+    annealing_epochs: 10
+    annealing_strategy: cos
+    device: null
+
+optimizers:
+  madgrad:
+    _target_: madgrad.MADGRAD
+    lr: 1.0e-4
+    momentum: 0.9
+    weight_decay: 5.0e-6
+    eps: 1.0e-6
+    parameters: network
+
+lr_schedulers:
+  network:
+    _target_: torch.optim.lr_scheduler.OneCycleLR
+    max_lr: 1.0e-4
+    total_steps: null
+    epochs: *epochs
+    steps_per_epoch: 722
+    pct_start: 0.01
+    anneal_strategy: cos
+    cycle_momentum: true
+    base_momentum: 0.85
+    max_momentum: 0.95
+    div_factor: 25
+    final_div_factor: 1.0e2
+    three_phase: false
+    last_epoch: -1
+    verbose: false
+    interval: step
+    monitor: val/loss
+
+datamodule:
+  batch_size: 32
+  num_workers: 12
+  train_fraction: 0.9
+  pin_memory: true
+  << : *mapping
+
+rotary_embedding: &rotary_embedding
+  rotary_embedding:
+    _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding
+    dim: 64
+
+attn: &attn
+  dim: 192
+  num_heads: 4
+  dim_head: 64
+  dropout_rate: 0.05
+
+network:
+  _target_: text_recognizer.networks.conv_transformer.ConvTransformer
+  input_dims: [1, 56, 1024]
+  hidden_dim: &hidden_dim 192
+  num_classes: *num_classes
+  pad_index: *ignore_index
+  encoder:
+    _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet
+    arch: b0
+    out_channels: 1280
+    stochastic_dropout_rate: 0.2
+    bn_momentum: 0.99
+    bn_eps: 1.0e-3
+  decoder:
+    depth: 4
+    _target_: text_recognizer.networks.transformer.layers.Decoder
+    self_attn:
+      _target_: text_recognizer.networks.transformer.attention.Attention
+      << : *attn
+      causal: true
+      << : *rotary_embedding
+    cross_attn:
+      _target_: text_recognizer.networks.transformer.attention.Attention
+      << : *attn
+      causal: false
+    norm:
+      _target_: text_recognizer.networks.transformer.norm.ScaleNorm
+      normalized_shape: *hidden_dim
+    ff:
+      _target_: text_recognizer.networks.transformer.mlp.FeedForward
+      dim: *hidden_dim
+      dim_out: null
+      expansion_factor: 4
+      glu: true
+      dropout_rate: 0.05
+    pre_norm: true
+  pixel_pos_embedding:
+    _target_: text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbedding
+    dim: *hidden_dim
+    shape: [1, 32]
+
+model:
+  _target_: text_recognizer.models.transformer.TransformerLitModel
+  << : *mapping
+  max_output_len: *max_output_len
+  start_token: <s>
+  end_token: <e>
+  pad_token: <p>
+
+trainer:
+  _target_: pytorch_lightning.Trainer
+  stochastic_weight_avg: true
+  auto_scale_batch_size: binsearch
+  auto_lr_find: false
+  gradient_clip_val: 0.5
+  fast_dev_run: false
+  gpus: 1
+  precision: 16
+  max_epochs: *epochs
+  terminate_on_nan: true
+  weights_summary: null
+  limit_train_batches: 1.0
+  limit_val_batches: 1.0
+  limit_test_batches: 1.0
+  resume_from_checkpoint: null
+  accumulate_grad_batches: 1
+  overfit_batches: 0
diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml
index 5fb7377..e958367 100644
--- a/training/conf/experiment/conv_transformer_paragraphs.yaml
+++ b/training/conf/experiment/conv_transformer_paragraphs.yaml
@@ -47,11 +47,11 @@ optimizers:
 lr_schedulers:
   network:
     _target_: torch.optim.lr_scheduler.OneCycleLR
-    max_lr: 1.5e-4
+    max_lr: 1.0e-4
     total_steps: null
     epochs: *epochs
     steps_per_epoch: 722
-    pct_start: 0.03
+    pct_start: 0.01
     anneal_strategy: cos
     cycle_momentum: true
     base_momentum: 0.85
@@ -87,8 +87,6 @@ network:
   _target_: text_recognizer.networks.conv_transformer.ConvTransformer
   input_dims: [1, 576, 640]
   hidden_dim: &hidden_dim 192
-  encoder_dim: 1280
-  dropout_rate: 0.05
   num_classes: *num_classes
   pad_index: *ignore_index
   encoder:
@@ -99,7 +97,7 @@ network:
     bn_momentum: 0.99
     bn_eps: 1.0e-3
   decoder:
-    depth: 4
+    depth: 3
     local_depth: 2
     _target_: text_recognizer.networks.transformer.layers.Decoder
     self_attn:
@@ -114,8 +112,9 @@ network:
     local_self_attn:
       _target_: text_recognizer.networks.transformer.local_attention.LocalAttention
       << : *attn
-      window_size: 11
-      look_back: 2
+      window_size: 31
+      look_back: 1
+      autopad: true
       << : *rotary_embedding
     norm:
       _target_: text_recognizer.networks.transformer.norm.ScaleNorm
diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml
index bc0678b..7dced16 100644
--- a/training/conf/network/decoder/transformer_decoder.yaml
+++ b/training/conf/network/decoder/transformer_decoder.yaml
@@ -1,21 +1,42 @@
-defaults:
-  - rotary_emb: null
-
 _target_: text_recognizer.networks.transformer.Decoder
-dim: 128
-depth: 2
-num_heads: 4
-attn_fn: text_recognizer.networks.transformer.attention.Attention
-attn_kwargs:
+depth: 4
+pre_norm: true
+local_depth: 2
+has_pos_emb: true
+self_attn:
+  _target_: text_recognizer.networks.transformer.attention.Attention
+  dim: 64
+  num_heads: 4
   dim_head: 64
-  dropout_rate: 0.2
-norm_fn: torch.nn.LayerNorm
-ff_fn: text_recognizer.networks.transformer.mlp.FeedForward
-ff_kwargs:
+  dropout_rate: 0.05
+  causal: true
+  rotary_embedding:
+    _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding
+    dim: 128
+local_self_attn:
+  _target_: text_recognizer.networks.transformer.local_attention.LocalAttention
+  dim: 64
+  num_heads: 4
+  dim_head: 64
+  dropout_rate: 0.05
+  window_size: 22
+  look_back: 1
+  rotary_embedding:
+    _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding
+    dim: 128
+cross_attn:
+  _target_: text_recognizer.networks.transformer.attention.Attention
+  dim: 64
+  num_heads: 4
+  dim_head: 64
+  dropout_rate: 0.05
+  causal: false
+norm:
+  _target_: text_recognizer.networks.transformer.norm.ScaleNorm
+  normalized_shape: 192
+ff:
+  _target_: text_recognizer.networks.transformer.mlp.FeedForward
   dim_out: null
   expansion_factor: 4
   glu: true
   dropout_rate: 0.2
-cross_attend: true
-pre_norm: true
-rotary_emb: null
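
For readers unfamiliar with the _target_ convention used throughout these files: each such node is meant to be turned into a live object with Hydra's instantiate utility. The sketch below illustrates that pattern with the lr_schedulers.network values from the new lines experiment only. It is not this repository's training code: the throwaway Linear module and SGD optimizer stand in for the real ConvTransformer and MADGRAD, and the interval/monitor keys are assumed to be Lightning scheduler bookkeeping rather than OneCycleLR arguments, so they are left out.

# Minimal sketch (not the repo's entry point): instantiate the OneCycleLR
# node from the new experiment config via Hydra's "_target_" mechanism.
import torch
from hydra.utils import instantiate

dummy = torch.nn.Linear(8, 8)  # stand-in for the ConvTransformer network
optimizer = torch.optim.SGD(dummy.parameters(), lr=1.0e-4, momentum=0.9)

scheduler = instantiate(
    {
        "_target_": "torch.optim.lr_scheduler.OneCycleLR",
        "max_lr": 1.0e-4,
        "total_steps": None,
        "epochs": 512,           # *epochs
        "steps_per_epoch": 722,
        "pct_start": 0.01,
        "anneal_strategy": "cos",
        "cycle_momentum": True,
        "base_momentum": 0.85,
        "max_momentum": 0.95,
        "div_factor": 25,
        "final_div_factor": 1.0e2,
        "three_phase": False,
    },
    optimizer=optimizer,
)

# epochs * steps_per_epoch = 512 * 722 = 369,664 optimizer steps in total;
# with pct_start = 0.01 the warmup to max_lr lasts roughly 3,700 steps.
print(scheduler.total_steps)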