author     Gustaf Rydholm <gustaf.rydholm@gmail.com>   2021-10-01 00:04:07 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>   2021-10-01 00:04:07 +0200
commit     9f786a0d052538be938bbf21d38a23754da9ab3b (patch)
tree       f944ae827582e47e211dd0dd51568e571aaca20a /training
parent     a0d0b61563802e6f5aa79fd13c3b25ac32e5b002 (diff)
Update config of IAM lines
Diffstat (limited to 'training')
-rw-r--r--  training/conf/experiment/cnn_htr_char_lines.yaml | 49 ++++++++++++++++++++++++++++++++++++-------------
1 file changed, 36 insertions(+), 13 deletions(-)
diff --git a/training/conf/experiment/cnn_htr_char_lines.yaml b/training/conf/experiment/cnn_htr_char_lines.yaml
index 08d9282..0f28ff9 100644
--- a/training/conf/experiment/cnn_htr_char_lines.yaml
+++ b/training/conf/experiment/cnn_htr_char_lines.yaml
@@ -11,13 +11,31 @@ defaults:
 criterion:
-  _target_: torch.nn.CrossEntropyLoss
-  ignore_index: 3
+  _target_: text_recognizer.criterions.label_smoothing.LabelSmoothingLoss
+  smoothing: 0.1
+  ignore_index: 1000
 mapping:
-  _target_: text_recognizer.data.emnist_mapping.EmnistMapping
+  _target_: text_recognizer.data.word_piece_mapping.WordPieceMapping
+  num_features: 1000
+  tokens: iamdb_1kwp_tokens_1000.txt
+  lexicon: iamdb_1kwp_lex_1000.txt
+  data_dir: null
+  use_words: false
+  prepend_wordsep: false
+  special_tokens: [ <s>, <e>, <p> ]
+  # _target_: text_recognizer.data.emnist_mapping.EmnistMapping
   # extra_symbols: [ "\n" ]
+callbacks:
+  stochastic_weight_averaging:
+    _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+    swa_epoch_start: 0.8
+    swa_lrs: 0.05
+    annealing_epochs: 10
+    annealing_strategy: cos
+    device: null
+
 optimizers:
   madgrad:
     _target_: madgrad.MADGRAD
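
Note: the criterion swaps plain CrossEntropyLoss for the repo's label-smoothing loss, with ignore_index raised from 3 to 1000 to match the new word-piece pad index. The implementation of text_recognizer.criterions.label_smoothing.LabelSmoothingLoss is not part of this diff; the following is a minimal generic sketch of a label-smoothing criterion with an ignore index, not the repo's actual code:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class LabelSmoothingLoss(nn.Module):
        # Generic sketch: spread `smoothing` probability mass uniformly over
        # the non-target classes and mask out padded positions.
        def __init__(self, smoothing: float = 0.1, ignore_index: int = -100) -> None:
            super().__init__()
            self.smoothing = smoothing
            self.ignore_index = ignore_index

        def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
            # logits: (N, num_classes), target: (N,) integer class indices.
            log_probs = F.log_softmax(logits, dim=-1)
            num_classes = logits.size(-1)
            smooth_target = torch.full_like(log_probs, self.smoothing / (num_classes - 1))
            smooth_target.scatter_(-1, target.unsqueeze(-1), 1.0 - self.smoothing)
            loss = -(smooth_target * log_probs).sum(dim=-1)
            mask = target != self.ignore_index  # drop pad positions (index 1000 here)
            return loss[mask].mean()
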
@@ -30,16 +48,21 @@ optimizers:
 lr_schedulers:
   network:
-    _target_: torch.optim.lr_scheduler.CosineAnnealingLR
-    T_max: 1024
-    eta_min: 4.5e-6
-    last_epoch: -1
+    _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+    mode: min
+    factor: 0.1
+    patience: 10
+    threshold: 1.0e-4
+    threshold_mode: rel
+    cooldown: 0
+    min_lr: 1.0e-7
+    eps: 1.0e-8
     interval: epoch
     monitor: val/loss
 datamodule:
   _target_: text_recognizer.data.iam_lines.IAMLines
-  batch_size: 24
+  batch_size: 16
   num_workers: 12
   train_fraction: 0.8
   augment: true
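
Note: the fixed cosine annealing schedule is replaced by ReduceLROnPlateau, which multiplies the learning rate by 0.1 whenever the monitored val/loss fails to improve (relatively, by more than 1e-4) for 10 consecutive epochs, bottoming out at 1e-7. Stripped of the Hydra/Lightning wiring, the equivalent plain PyTorch usage would look roughly like this (model, train_one_epoch, and validate are placeholders):

    import torch
    from madgrad import MADGRAD  # the optimizer configured above

    optimizer = MADGRAD(model.parameters(), lr=1e-3)  # lr value is illustrative
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",           # val/loss should decrease
        factor=0.1,           # lr <- lr * 0.1 on plateau
        patience=10,          # epochs to wait without improvement
        threshold=1.0e-4,
        threshold_mode="rel",
        cooldown=0,
        min_lr=1.0e-7,
        eps=1.0e-8,
    )

    for epoch in range(1024):  # max_epochs in the trainer config below
        train_one_epoch(model, optimizer)
        val_loss = validate(model)
        scheduler.step(val_loss)  # interval: epoch, monitor: val/loss
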
@@ -51,8 +74,8 @@ network:
   hidden_dim: 128
   encoder_dim: 1280
   dropout_rate: 0.2
-  num_classes: 58
-  pad_index: 3
+  num_classes: 1006
+  pad_index: 1000
   encoder:
     _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet
     arch: b0
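
Note: num_classes grows from 58 (the EMNIST character set) to 1006 and pad_index moves from 3 to 1000, consistent with the 1000-feature word-piece vocabulary above; the pad index sits just past the word pieces, and the remaining indices presumably hold the special and extra symbols (the exact breakdown is not visible in this diff). A quick consistency check of the values as configured:

    # Values taken from this diff; the breakdown of the last six indices is assumed.
    NUM_WORD_PIECES = 1000   # mapping.num_features
    NUM_CLASSES = 1006       # network.num_classes: word pieces + special symbols
    PAD_INDEX = 1000         # network.pad_index == criterion.ignore_index

    assert PAD_INDEX < NUM_CLASSES
    # Because pad_index equals the criterion's ignore_index, padded positions
    # contribute nothing to the training loss.
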
@@ -89,7 +112,7 @@ model:
 trainer:
   _target_: pytorch_lightning.Trainer
-  stochastic_weight_avg: false
+  stochastic_weight_avg: true
   auto_scale_batch_size: binsearch
   auto_lr_find: false
   gradient_clip_val: 0
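
Note: stochastic weight averaging is now switched on twice, via this Trainer flag and via the explicit callback added in the first hunk. In the PyTorch Lightning releases of this period the callback form makes the hyperparameters explicit; a standalone sketch (model and datamodule are placeholders):

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import StochasticWeightAveraging

    swa = StochasticWeightAveraging(
        swa_epoch_start=0.8,       # start averaging at 80% of max_epochs
        swa_lrs=0.05,              # learning rate used while averaging
        annealing_epochs=10,       # epochs to anneal from the current lr to swa_lrs
        annealing_strategy="cos",
    )

    trainer = pl.Trainer(max_epochs=1024, precision=16, callbacks=[swa])
    trainer.fit(model, datamodule=datamodule)
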
@@ -98,7 +121,7 @@ trainer:
   precision: 16
   max_epochs: 1024
   terminate_on_nan: true
-  weights_summary: top
+  weights_summary: null
   limit_train_batches: 1.0
   limit_val_batches: 1.0
   limit_test_batches: 1.0
@@ -106,4 +129,4 @@ trainer:
   accumulate_grad_batches: 4
   overfit_batches: 0.0
 
-# summary: [[1, 1, 56, 1024], [1, 89]]
+summary: [[1, 1, 56, 1024], [1, 89]]
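
Note: re-enabling summary restores the dummy input shapes used when printing the model summary: one 1-channel 56x1024 line image and one length-89 token sequence. Dummy tensors with those shapes, assuming the network takes (image, targets) in its forward pass:

    import torch

    image = torch.randn(1, 1, 56, 1024)        # (batch, channels, height, width)
    tokens = torch.randint(0, 1006, (1, 89))   # (batch, sequence length) of class indices
    logits = network(image, tokens)            # `network` stands in for the configured model
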