diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-11 22:12:00 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-11 22:12:00 +0200 |
commit | 42e5da95b31ac35376f483fe1816031c9339ec5c (patch) | |
tree | b3148fe0d0cf69fd7b885697843f38c480ffc369 /training/conf/experiment/cnn_transformer_paragraphs.yaml | |
parent | c9e50644ba14ae09aec5a44c12f8116bada26bab (diff) |
Update ConvTransform config
Diffstat (limited to 'training/conf/experiment/cnn_transformer_paragraphs.yaml')
-rw-r--r-- | training/conf/experiment/cnn_transformer_paragraphs.yaml | 25 |
1 files changed, 14 insertions, 11 deletions
diff --git a/training/conf/experiment/cnn_transformer_paragraphs.yaml b/training/conf/experiment/cnn_transformer_paragraphs.yaml index 910d408..5ee5536 100644 --- a/training/conf/experiment/cnn_transformer_paragraphs.yaml +++ b/training/conf/experiment/cnn_transformer_paragraphs.yaml @@ -1,7 +1,10 @@ +# @package _global_ + defaults: - override /mapping: null - override /criterion: null - - override /datamodule: null + - override /callbacks: htr + - override /datamodule: iam_extended_paragraphs - override /network: null - override /model: null - override /lr_schedulers: null @@ -18,9 +21,10 @@ criterion: _target_: torch.nn.CrossEntropyLoss ignore_index: *ignore_index -mapping: - _target_: text_recognizer.data.emnist_mapping.EmnistMapping - extra_symbols: [ "\n" ] +mapping: &mapping + mapping: + _target_: text_recognizer.data.mappings.emnist_mapping.EmnistMapping + extra_symbols: [ "\n" ] callbacks: stochastic_weight_averaging: @@ -34,9 +38,9 @@ callbacks: optimizers: madgrad: _target_: madgrad.MADGRAD - lr: 2.0e-4 + lr: 3.0e-4 momentum: 0.9 - weight_decay: 0 + weight_decay: 5.0e-6 eps: 1.0e-6 parameters: network @@ -44,11 +48,11 @@ optimizers: lr_schedulers: network: _target_: torch.optim.lr_scheduler.OneCycleLR - max_lr: 2.0e-4 + max_lr: 3.0e-4 total_steps: null epochs: *epochs steps_per_epoch: 632 - pct_start: 0.3 + pct_start: 0.03 anneal_strategy: cos cycle_momentum: true base_momentum: 0.85 @@ -67,10 +71,8 @@ datamodule: batch_size: 4 num_workers: 12 train_fraction: 0.8 - augment: true pin_memory: true - word_pieces: false - resize: null + << : *mapping network: _target_: text_recognizer.networks.conv_transformer.ConvTransformer @@ -121,6 +123,7 @@ network: model: _target_: text_recognizer.models.transformer.TransformerLitModel + << : *mapping max_output_len: *max_output_len start_token: <s> end_token: <e> |