From 9f786a0d052538be938bbf21d38a23754da9ab3b Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Fri, 1 Oct 2021 00:04:07 +0200 Subject: Update config of IAM lines --- training/conf/experiment/cnn_htr_char_lines.yaml | 49 +++++++++++++++++------- 1 file changed, 36 insertions(+), 13 deletions(-) (limited to 'training') diff --git a/training/conf/experiment/cnn_htr_char_lines.yaml b/training/conf/experiment/cnn_htr_char_lines.yaml index 08d9282..0f28ff9 100644 --- a/training/conf/experiment/cnn_htr_char_lines.yaml +++ b/training/conf/experiment/cnn_htr_char_lines.yaml @@ -11,13 +11,31 @@ defaults: criterion: - _target_: torch.nn.CrossEntropyLoss - ignore_index: 3 + _target_: text_recognizer.criterions.label_smoothing.LabelSmoothingLoss + smoothing: 0.1 + ignore_index: 1000 mapping: - _target_: text_recognizer.data.emnist_mapping.EmnistMapping + _target_: text_recognizer.data.word_piece_mapping.WordPieceMapping + num_features: 1000 + tokens: iamdb_1kwp_tokens_1000.txt + lexicon: iamdb_1kwp_lex_1000.txt + data_dir: null + use_words: false + prepend_wordsep: false + special_tokens: [ <s>, <e>, <p> ] + # _target_: text_recognizer.data.emnist_mapping.EmnistMapping # extra_symbols: [ "\n" ] +callbacks: + stochastic_weight_averaging: + _target_: pytorch_lightning.callbacks.StochasticWeightAveraging + swa_epoch_start: 0.8 + swa_lrs: 0.05 + annealing_epochs: 10 + annealing_strategy: cos + device: null + optimizers: madgrad: _target_: madgrad.MADGRAD @@ -30,16 +48,21 @@ optimizers: lr_schedulers: network: - _target_: torch.optim.lr_scheduler.CosineAnnealingLR - T_max: 1024 - eta_min: 4.5e-6 - last_epoch: -1 + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + mode: min + factor: 0.1 + patience: 10 + threshold: 1.0e-4 + threshold_mode: rel + cooldown: 0 + min_lr: 1.0e-7 + eps: 1.0e-8 interval: epoch monitor: val/loss datamodule: _target_: text_recognizer.data.iam_lines.IAMLines - batch_size: 24 + batch_size: 16 num_workers: 12 train_fraction: 0.8 augment: true @@ -51,8 +74,8 @@ network: hidden_dim: 128 encoder_dim: 1280 dropout_rate: 0.2 - num_classes: 58 - pad_index: 3 + num_classes: 1006 + pad_index: 1000 encoder: _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet arch: b0 @@ -89,7 +112,7 @@ model: trainer: _target_: pytorch_lightning.Trainer - stochastic_weight_avg: false + stochastic_weight_avg: true auto_scale_batch_size: binsearch auto_lr_find: false gradient_clip_val: 0 @@ -98,7 +121,7 @@ trainer: precision: 16 max_epochs: 1024 terminate_on_nan: true - weights_summary: top + weights_summary: null limit_train_batches: 1.0 limit_val_batches: 1.0 limit_test_batches: 1.0 @@ -106,4 +129,4 @@ trainer: accumulate_grad_batches: 4 overfit_batches: 0.0 -# summary: [[1, 1, 56, 1024], [1, 89]] +summary: [[1, 1, 56, 1024], [1, 89]] -- cgit v1.2.3-70-g09d2