From 84e11d0a0c4e494b6e8651189502fc040a0deaaf Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 1 Nov 2021 00:36:33 +0100 Subject: Update to config --- .../experiment/conv_transformer_paragraphs.yaml | 66 ++++++++++++---------- 1 file changed, 37 insertions(+), 29 deletions(-) (limited to 'training/conf/experiment') diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml index 8c3af44..5fb7377 100644 --- a/training/conf/experiment/conv_transformer_paragraphs.yaml +++ b/training/conf/experiment/conv_transformer_paragraphs.yaml @@ -47,10 +47,10 @@ optimizers: lr_schedulers: network: _target_: torch.optim.lr_scheduler.OneCycleLR - max_lr: 1.0e-4 + max_lr: 1.5e-4 total_steps: null epochs: *epochs - steps_per_epoch: 211 + steps_per_epoch: 722 pct_start: 0.03 anneal_strategy: cos cycle_momentum: true @@ -72,12 +72,23 @@ datamodule: pin_memory: true << : *mapping +rotary_embedding: &rotary_embedding + rotary_embedding: + _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding + dim: 64 + +attn: &attn + dim: 192 + num_heads: 4 + dim_head: 64 + dropout_rate: 0.05 + network: _target_: text_recognizer.networks.conv_transformer.ConvTransformer input_dims: [1, 576, 640] hidden_dim: &hidden_dim 192 encoder_dim: 1280 - dropout_rate: 0.1 + dropout_rate: 0.05 num_classes: *num_classes pad_index: *ignore_index encoder: @@ -88,42 +99,39 @@ network: bn_momentum: 0.99 bn_eps: 1.0e-3 decoder: + depth: 4 + local_depth: 2 _target_: text_recognizer.networks.transformer.layers.Decoder - dim: *hidden_dim - depth: 3 - num_heads: 4 - attn_fn: text_recognizer.networks.transformer.attention.Attention - attn_kwargs: - dim_head: 32 - dropout_rate: 0.05 - local_attn_fn: text_recognizer.networks.transformer.local_attention.LocalAttention - local_attn_kwargs: - dim_head: 32 - dropout_rate: 0.05 + self_attn: + _target_: text_recognizer.networks.transformer.attention.Attention + << : *attn + causal: true + << : *rotary_embedding + cross_attn: + _target_: text_recognizer.networks.transformer.attention.Attention + << : *attn + causal: false + local_self_attn: + _target_: text_recognizer.networks.transformer.local_attention.LocalAttention + << : *attn window_size: 11 look_back: 2 - depth: 2 - norm_fn: text_recognizer.networks.transformer.norm.ScaleNorm - ff_fn: text_recognizer.networks.transformer.mlp.FeedForward - ff_kwargs: + << : *rotary_embedding + norm: + _target_: text_recognizer.networks.transformer.norm.ScaleNorm + normalized_shape: *hidden_dim + ff: + _target_: text_recognizer.networks.transformer.mlp.FeedForward + dim: *hidden_dim dim_out: null expansion_factor: 4 glu: true dropout_rate: 0.05 - cross_attend: true pre_norm: true - rotary_emb: - _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding - dim: 32 pixel_pos_embedding: _target_: text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbedding - dim: *hidden_dim + dim: *hidden_dim shape: [18, 20] - token_pos_embedding: - _target_: text_recognizer.networks.transformer.embeddings.fourier.PositionalEncoding - hidden_dim: *hidden_dim - dropout_rate: 0.05 - max_len: *max_output_len model: _target_: text_recognizer.models.transformer.TransformerLitModel @@ -149,5 +157,5 @@ trainer: limit_val_batches: 1.0 limit_test_batches: 1.0 resume_from_checkpoint: null - accumulate_grad_batches: 16 + accumulate_grad_batches: 7 overfit_batches: 0 -- cgit v1.2.3-70-g09d2