path: root/training/conf/experiment/conv_transformer_paragraphs.yaml
Diffstat (limited to 'training/conf/experiment/conv_transformer_paragraphs.yaml')
-rw-r--r--  training/conf/experiment/conv_transformer_paragraphs.yaml | 66
1 file changed, 37 insertions(+), 29 deletions(-)
diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml
index 8c3af44..5fb7377 100644
--- a/training/conf/experiment/conv_transformer_paragraphs.yaml
+++ b/training/conf/experiment/conv_transformer_paragraphs.yaml
@@ -47,10 +47,10 @@ optimizers:
 lr_schedulers:
   network:
     _target_: torch.optim.lr_scheduler.OneCycleLR
-    max_lr: 1.0e-4
+    max_lr: 1.5e-4
     total_steps: null
     epochs: *epochs
-    steps_per_epoch: 211
+    steps_per_epoch: 722
     pct_start: 0.03
     anneal_strategy: cos
     cycle_momentum: true
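
Note: a minimal sketch, not from the repo, of how this lr_schedulers block maps onto the PyTorch constructor; the model, optimizer, and the literal epochs value (standing in for the *epochs anchor) are placeholders. With total_steps: null, OneCycleLR derives the cycle length as epochs * steps_per_epoch, so raising steps_per_epoch from 211 to 722 stretches the schedule to match the higher number of optimizer steps per epoch.

    # Sketch only: placeholder model/optimizer; `epochs` stands in for the
    # *epochs anchor defined elsewhere in this config.
    import torch

    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.Adam(model.parameters(), lr=1.5e-4)
    epochs = 512  # placeholder value

    # total_steps is omitted, so the scheduler computes it as
    # epochs * steps_per_epoch.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=1.5e-4,
        epochs=epochs,
        steps_per_epoch=722,
        pct_start=0.03,
        anneal_strategy="cos",
        cycle_momentum=True,  # cycles Adam's beta1 in place of momentum
    )
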
@@ -72,12 +72,23 @@ datamodule:
   pin_memory: true
   << : *mapping

+rotary_embedding: &rotary_embedding
+  rotary_embedding:
+    _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding
+    dim: 64
+
+attn: &attn
+  dim: 192
+  num_heads: 4
+  dim_head: 64
+  dropout_rate: 0.05
+
 network:
   _target_: text_recognizer.networks.conv_transformer.ConvTransformer
   input_dims: [1, 576, 640]
   hidden_dim: &hidden_dim 192
   encoder_dim: 1280
-  dropout_rate: 0.1
+  dropout_rate: 0.05
   num_classes: *num_classes
   pad_index: *ignore_index
   encoder:
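
Note: the new &rotary_embedding and &attn blocks are plain YAML anchors; the `<< : *attn` merge keys in the decoder hunk below copy the anchored mapping into each attention config at parse time, and merging *rotary_embedding likewise injects a rotary_embedding key (its dim: 64 matches dim_head: 64, i.e. one rotary embedding per head dimension). A minimal sketch of the merge-key semantics under PyYAML, which Hydra uses to read these files:

    # Sketch of YAML merge-key expansion; standalone PyYAML, not the repo's loader.
    import yaml

    snippet = """
    attn: &attn
      dim: 192
      num_heads: 4
      dim_head: 64
      dropout_rate: 0.05

    self_attn:
      << : *attn
      causal: true
    """

    print(yaml.safe_load(snippet)["self_attn"])
    # -> {'dim': 192, 'num_heads': 4, 'dim_head': 64,
    #     'dropout_rate': 0.05, 'causal': True}
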
@@ -88,42 +99,39 @@ network:
     bn_momentum: 0.99
     bn_eps: 1.0e-3
   decoder:
+    depth: 4
+    local_depth: 2
     _target_: text_recognizer.networks.transformer.layers.Decoder
-    dim: *hidden_dim
-    depth: 3
-    num_heads: 4
-    attn_fn: text_recognizer.networks.transformer.attention.Attention
-    attn_kwargs:
-      dim_head: 32
-      dropout_rate: 0.05
-    local_attn_fn: text_recognizer.networks.transformer.local_attention.LocalAttention
-    local_attn_kwargs:
-      dim_head: 32
-      dropout_rate: 0.05
+    self_attn:
+      _target_: text_recognizer.networks.transformer.attention.Attention
+      << : *attn
+      causal: true
+      << : *rotary_embedding
+    cross_attn:
+      _target_: text_recognizer.networks.transformer.attention.Attention
+      << : *attn
+      causal: false
+    local_self_attn:
+      _target_: text_recognizer.networks.transformer.local_attention.LocalAttention
+      << : *attn
       window_size: 11
       look_back: 2
-      depth: 2
-    norm_fn: text_recognizer.networks.transformer.norm.ScaleNorm
-    ff_fn: text_recognizer.networks.transformer.mlp.FeedForward
-    ff_kwargs:
+      << : *rotary_embedding
+    norm:
+      _target_: text_recognizer.networks.transformer.norm.ScaleNorm
+      normalized_shape: *hidden_dim
+    ff:
+      _target_: text_recognizer.networks.transformer.mlp.FeedForward
+      dim: *hidden_dim
       dim_out: null
       expansion_factor: 4
       glu: true
       dropout_rate: 0.05
-    cross_attend: true
     pre_norm: true
-    rotary_emb:
-      _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding
-      dim: 32
   pixel_pos_embedding:
     _target_: text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbedding
-     dim: *hidden_dim
+    dim: *hidden_dim
     shape: [18, 20]
-  token_pos_embedding:
-    _target_: text_recognizer.networks.transformer.embeddings.fourier.PositionalEncoding
-    hidden_dim: *hidden_dim
-    dropout_rate: 0.05
-    max_len: *max_output_len

 model:
   _target_: text_recognizer.models.transformer.TransformerLitModel
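
Note: the pixel_pos_embedding shape appears to track the encoder's output grid rather than being arbitrary. A quick check of that arithmetic, assuming the convolutional encoder downsamples its input by a total stride of 32 (an assumption; the encoder's stride is not shown in this hunk):

    # Hedged check of the axial positional-embedding grid; the 32x total
    # stride is an assumption about the conv encoder, not read from the diff.
    input_dims = (1, 576, 640)  # (channels, height, width) from the network block
    stride = 32                 # assumed total downsampling of the encoder

    grid = (input_dims[1] // stride, input_dims[2] // stride)
    print(grid)  # -> (18, 20), matching `shape: [18, 20]` above
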
@@ -149,5 +157,5 @@ trainer:
   limit_val_batches: 1.0
   limit_test_batches: 1.0
   resume_from_checkpoint: null
-  accumulate_grad_batches: 16
+  accumulate_grad_batches: 7
   overfit_batches: 0
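
Note: the accumulate_grad_batches change pairs with the steps_per_epoch change in the scheduler hunk: an optimizer step (and hence a OneCycleLR step) fires once per accumulation window, so steps_per_epoch should equal ceil(batches_per_epoch / accumulate_grad_batches). A sketch of that consistency check; the batch count is hypothetical, chosen only so the numbers line up with this diff:

    # Hedged consistency check; 5054 batches/epoch is a hypothetical figure
    # picked to be consistent with steps_per_epoch: 722 and accumulation 7.
    import math

    def optimizer_steps_per_epoch(batches_per_epoch: int, accumulation: int) -> int:
        # One optimizer/scheduler step per accumulation window.
        return math.ceil(batches_per_epoch / accumulation)

    print(optimizer_steps_per_epoch(5054, 7))  # -> 722
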