diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-01 00:04:07 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-01 00:04:07 +0200 |
commit | 9f786a0d052538be938bbf21d38a23754da9ab3b (patch) | |
tree | f944ae827582e47e211dd0dd51568e571aaca20a /training/conf/experiment | |
parent | a0d0b61563802e6f5aa79fd13c3b25ac32e5b002 (diff) |
Update config of IAM lines
Diffstat (limited to 'training/conf/experiment')
-rw-r--r-- | training/conf/experiment/cnn_htr_char_lines.yaml | 49 |
1 files changed, 36 insertions, 13 deletions
diff --git a/training/conf/experiment/cnn_htr_char_lines.yaml b/training/conf/experiment/cnn_htr_char_lines.yaml index 08d9282..0f28ff9 100644 --- a/training/conf/experiment/cnn_htr_char_lines.yaml +++ b/training/conf/experiment/cnn_htr_char_lines.yaml @@ -11,13 +11,31 @@ defaults: criterion: - _target_: torch.nn.CrossEntropyLoss - ignore_index: 3 + _target_: text_recognizer.criterions.label_smoothing.LabelSmoothingLoss + smoothing: 0.1 + ignore_index: 1000 mapping: - _target_: text_recognizer.data.emnist_mapping.EmnistMapping + _target_: text_recognizer.data.word_piece_mapping.WordPieceMapping + num_features: 1000 + tokens: iamdb_1kwp_tokens_1000.txt + lexicon: iamdb_1kwp_lex_1000.txt + data_dir: null + use_words: false + prepend_wordsep: false + special_tokens: [ <s>, <e>, <p> ] + # _target_: text_recognizer.data.emnist_mapping.EmnistMapping # extra_symbols: [ "\n" ] +callbacks: + stochastic_weight_averaging: + _target_: pytorch_lightning.callbacks.StochasticWeightAveraging + swa_epoch_start: 0.8 + swa_lrs: 0.05 + annealing_epochs: 10 + annealing_strategy: cos + device: null + optimizers: madgrad: _target_: madgrad.MADGRAD @@ -30,16 +48,21 @@ optimizers: lr_schedulers: network: - _target_: torch.optim.lr_scheduler.CosineAnnealingLR - T_max: 1024 - eta_min: 4.5e-6 - last_epoch: -1 + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + mode: min + factor: 0.1 + patience: 10 + threshold: 1.0e-4 + threshold_mode: rel + cooldown: 0 + min_lr: 1.0e-7 + eps: 1.0e-8 interval: epoch monitor: val/loss datamodule: _target_: text_recognizer.data.iam_lines.IAMLines - batch_size: 24 + batch_size: 16 num_workers: 12 train_fraction: 0.8 augment: true @@ -51,8 +74,8 @@ network: hidden_dim: 128 encoder_dim: 1280 dropout_rate: 0.2 - num_classes: 58 - pad_index: 3 + num_classes: 1006 + pad_index: 1000 encoder: _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet arch: b0 @@ -89,7 +112,7 @@ model: trainer: _target_: pytorch_lightning.Trainer - stochastic_weight_avg: false + stochastic_weight_avg: true auto_scale_batch_size: binsearch auto_lr_find: false gradient_clip_val: 0 @@ -98,7 +121,7 @@ trainer: precision: 16 max_epochs: 1024 terminate_on_nan: true - weights_summary: top + weights_summary: null limit_train_batches: 1.0 limit_val_batches: 1.0 limit_test_batches: 1.0 @@ -106,4 +129,4 @@ trainer: accumulate_grad_batches: 4 overfit_batches: 0.0 -# summary: [[1, 1, 56, 1024], [1, 89]] +summary: [[1, 1, 56, 1024], [1, 89]] |