diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/conf/experiment/conv_transformer_paragraphs.yaml | 13 |
-rw-r--r-- | training/conf/network/conv_transformer.yaml | 4 |
2 files changed, 8 insertions, 9 deletions
diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml
index 7c6e231..c8db485 100644
--- a/training/conf/experiment/conv_transformer_paragraphs.yaml
+++ b/training/conf/experiment/conv_transformer_paragraphs.yaml
@@ -34,15 +34,14 @@ optimizer:
   betas: [0.9, 0.999]
   weight_decay: 0
   eps: 1.0e-8
-  parameters: network
 
 lr_scheduler:
   _target_: torch.optim.lr_scheduler.OneCycleLR
   max_lr: 3.0e-4
   total_steps: null
   epochs: *epochs
-  steps_per_epoch: 3201
-  pct_start: 0.3
+  steps_per_epoch: 5037
+  pct_start: 0.15
   anneal_strategy: cos
   cycle_momentum: true
   base_momentum: 0.85
@@ -56,7 +55,7 @@ lr_scheduler:
   monitor: val/cer
 
 datamodule:
-  batch_size: 6
+  batch_size: 4
   train_fraction: 0.95
 
 network:
@@ -66,9 +65,9 @@ network:
   encoder:
     depth: 5
   decoder:
-    depth: 6
+    depth: 4
   pixel_embedding:
-    shape: [18, 78]
+    shape: [17, 79]
 
 model:
   max_output_len: *max_output_len
@@ -76,4 +75,4 @@ model:
 trainer:
   gradient_clip_val: 0.5
   max_epochs: *epochs
-  accumulate_grad_batches: 1
+  accumulate_grad_batches: 2
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml
index ccdf960..90c2cb8 100644
--- a/training/conf/network/conv_transformer.yaml
+++ b/training/conf/network/conv_transformer.yaml
@@ -9,7 +9,7 @@ encoder:
   stochastic_dropout_rate: 0.2
   bn_momentum: 0.99
   bn_eps: 1.0e-3
-  depth: 3
+  depth: 5
   out_channels: *hidden_dim
   stride: [2, 1]
 decoder:
@@ -47,4 +47,4 @@ decoder:
 pixel_embedding:
   _target_: text_recognizer.networks.transformer.AxialPositionalEmbedding
   dim: *hidden_dim
-  shape: [17, 78]
+  shape: [18, 78]