diff options
Diffstat (limited to 'training/conf')
| -rw-r--r-- | training/conf/experiment/cnn_htr_char_lines.yaml | 49 | 
1 files changed, 36 insertions, 13 deletions
diff --git a/training/conf/experiment/cnn_htr_char_lines.yaml b/training/conf/experiment/cnn_htr_char_lines.yaml index 08d9282..0f28ff9 100644 --- a/training/conf/experiment/cnn_htr_char_lines.yaml +++ b/training/conf/experiment/cnn_htr_char_lines.yaml @@ -11,13 +11,31 @@ defaults:  criterion: -  _target_: torch.nn.CrossEntropyLoss -  ignore_index: 3 +  _target_: text_recognizer.criterions.label_smoothing.LabelSmoothingLoss +  smoothing: 0.1  +  ignore_index: 1000  mapping: -  _target_: text_recognizer.data.emnist_mapping.EmnistMapping +  _target_: text_recognizer.data.word_piece_mapping.WordPieceMapping +  num_features: 1000 +  tokens: iamdb_1kwp_tokens_1000.txt +  lexicon: iamdb_1kwp_lex_1000.txt +  data_dir: null +  use_words: false +  prepend_wordsep: false +  special_tokens: [ <s>, <e>, <p> ] +  # _target_: text_recognizer.data.emnist_mapping.EmnistMapping    # extra_symbols: [ "\n" ] +callbacks: +  stochastic_weight_averaging: +    _target_: pytorch_lightning.callbacks.StochasticWeightAveraging +    swa_epoch_start: 0.8 +    swa_lrs: 0.05 +    annealing_epochs: 10 +    annealing_strategy: cos +    device: null +  optimizers:    madgrad:      _target_: madgrad.MADGRAD @@ -30,16 +48,21 @@ optimizers:  lr_schedulers:    network: -    _target_: torch.optim.lr_scheduler.CosineAnnealingLR -    T_max: 1024 -    eta_min: 4.5e-6 -    last_epoch: -1 +    _target_: torch.optim.lr_scheduler.ReduceLROnPlateau +    mode: min +    factor: 0.1 +    patience: 10 +    threshold: 1.0e-4 +    threshold_mode: rel +    cooldown: 0 +    min_lr: 1.0e-7 +    eps: 1.0e-8      interval: epoch      monitor: val/loss  datamodule:    _target_: text_recognizer.data.iam_lines.IAMLines -  batch_size: 24 +  batch_size: 16    num_workers: 12    train_fraction: 0.8    augment: true @@ -51,8 +74,8 @@ network:    hidden_dim: 128    encoder_dim: 1280    dropout_rate: 0.2 -  num_classes: 58 -  pad_index: 3 +  num_classes: 1006 +  pad_index: 1000    encoder:      _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet      arch: b0 @@ -89,7 +112,7 @@ model:  trainer:    _target_: pytorch_lightning.Trainer -  stochastic_weight_avg: false +  stochastic_weight_avg: true    auto_scale_batch_size: binsearch    auto_lr_find: false    gradient_clip_val: 0 @@ -98,7 +121,7 @@ trainer:    precision: 16    max_epochs: 1024    terminate_on_nan: true -  weights_summary: top +  weights_summary: null    limit_train_batches: 1.0     limit_val_batches: 1.0    limit_test_batches: 1.0 @@ -106,4 +129,4 @@ trainer:    accumulate_grad_batches: 4    overfit_batches: 0.0 -# summary: [[1, 1, 56, 1024], [1, 89]] +summary: [[1, 1, 56, 1024], [1, 89]]  |