diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-10 18:06:54 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-10 18:06:54 +0200 |
commit | b69254ce3135c112e29f7f1c986b7f0817da0c33 (patch) | |
tree | e596a719b445c0cbbd4108079206ec9d14de1437 /training | |
parent | cba94bdcab90f288dd1172607500ba2b28279736 (diff) |
Update configs
Diffstat (limited to 'training')
27 files changed, 13 insertions, 112 deletions
diff --git a/training/conf/callbacks/checkpoint.yaml b/training/conf/callbacks/checkpoint.yaml deleted file mode 100644 index b4101d8..0000000 --- a/training/conf/callbacks/checkpoint.yaml +++ /dev/null @@ -1,9 +0,0 @@ -model_checkpoint: - _target_: pytorch_lightning.callbacks.ModelCheckpoint - monitor: val/loss # name of the logged metric which determines when model is improving - save_top_k: 1 # save k best models (determined by above metric) - save_last: true # additionaly always save model from last epoch - mode: min # can be "max" or "min" - verbose: false - dirpath: checkpoints/ - filename: "{epoch:02d}" diff --git a/training/conf/callbacks/default.yaml b/training/conf/callbacks/default.yaml index 658fc03..c184039 100644 --- a/training/conf/callbacks/default.yaml +++ b/training/conf/callbacks/default.yaml @@ -1,3 +1,3 @@ defaults: - - checkpoint - - learning_rate_monitor + - lightning: checkpoint + - lightning: learning_rate_monitor diff --git a/training/conf/callbacks/early_stopping.yaml b/training/conf/callbacks/early_stopping.yaml deleted file mode 100644 index a188df3..0000000 --- a/training/conf/callbacks/early_stopping.yaml +++ /dev/null @@ -1,6 +0,0 @@ -early_stopping: - _target_: pytorch_lightning.callbacks.EarlyStopping - monitor: val/loss # name of the logged metric which determines when model is improving - patience: 16 # how many epochs of not improving until training stops - mode: min # can be "max" or "min" - min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement diff --git a/training/conf/callbacks/learning_rate_monitor.yaml b/training/conf/callbacks/learning_rate_monitor.yaml deleted file mode 100644 index 4a14e1f..0000000 --- a/training/conf/callbacks/learning_rate_monitor.yaml +++ /dev/null @@ -1,4 +0,0 @@ -learning_rate_monitor: - _target_: pytorch_lightning.callbacks.LearningRateMonitor - logging_interval: step - log_momentum: false diff --git a/training/conf/callbacks/swa.yaml b/training/conf/callbacks/swa.yaml deleted file mode 100644 index 73f8c66..0000000 --- a/training/conf/callbacks/swa.yaml +++ /dev/null @@ -1,7 +0,0 @@ -stochastic_weight_averaging: - _target_: pytorch_lightning.callbacks.StochasticWeightAveraging - swa_epoch_start: 0.8 - swa_lrs: 0.05 - annealing_epochs: 10 - annealing_strategy: cos - device: null diff --git a/training/conf/callbacks/wandb_checkpoints.yaml b/training/conf/callbacks/wandb_checkpoints.yaml deleted file mode 100644 index a4a16ff..0000000 --- a/training/conf/callbacks/wandb_checkpoints.yaml +++ /dev/null @@ -1,4 +0,0 @@ -upload_ckpts_as_artifact: - _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact - ckpt_dir: checkpoints/ - upload_best_only: true diff --git a/training/conf/callbacks/wandb_config.yaml b/training/conf/callbacks/wandb_config.yaml deleted file mode 100644 index 747a7c6..0000000 --- a/training/conf/callbacks/wandb_config.yaml +++ /dev/null @@ -1,2 +0,0 @@ -upload_code_as_artifact: - _target_: callbacks.wandb_callbacks.UploadConfigAsArtifact diff --git a/training/conf/callbacks/wandb_htr.yaml b/training/conf/callbacks/wandb_htr.yaml deleted file mode 100644 index f8c1ef7..0000000 --- a/training/conf/callbacks/wandb_htr.yaml +++ /dev/null @@ -1,6 +0,0 @@ -defaults: - - default - - wandb_watch - - wandb_config - - wandb_checkpoints - - wandb_htr_predictions diff --git a/training/conf/callbacks/wandb_htr_predictions.yaml b/training/conf/callbacks/wandb_htr_predictions.yaml deleted file mode 100644 index 468b6e0..0000000 --- a/training/conf/callbacks/wandb_htr_predictions.yaml +++ /dev/null @@ -1,4 +0,0 @@ -log_text_predictions: - _target_: callbacks.wandb_callbacks.LogTextPredictions - num_samples: 8 - log_train: false diff --git a/training/conf/callbacks/wandb_image_reconstructions.yaml b/training/conf/callbacks/wandb_image_reconstructions.yaml deleted file mode 100644 index fabfe31..0000000 --- a/training/conf/callbacks/wandb_image_reconstructions.yaml +++ /dev/null @@ -1,5 +0,0 @@ -log_image_reconstruction: - _target_: callbacks.wandb_callbacks.LogReconstuctedImages - num_samples: 8 - log_train: true - use_sigmoid: true diff --git a/training/conf/callbacks/wandb_vae.yaml b/training/conf/callbacks/wandb_vae.yaml deleted file mode 100644 index ffc467f..0000000 --- a/training/conf/callbacks/wandb_vae.yaml +++ /dev/null @@ -1,6 +0,0 @@ -defaults: - - default - - wandb_watch - - wandb_checkpoints - - wandb_image_reconstructions - - wandb_config diff --git a/training/conf/callbacks/wandb_watch.yaml b/training/conf/callbacks/wandb_watch.yaml deleted file mode 100644 index 511608c..0000000 --- a/training/conf/callbacks/wandb_watch.yaml +++ /dev/null @@ -1,4 +0,0 @@ -watch_model: - _target_: callbacks.wandb_callbacks.WatchModel - log: all - log_freq: 100 diff --git a/training/conf/config.yaml b/training/conf/config.yaml index 9ed366f..a58891b 100644 --- a/training/conf/config.yaml +++ b/training/conf/config.yaml @@ -1,7 +1,7 @@ # @package _global_ defaults: - - callbacks: wandb_htr + - callbacks: htr - criterion: label_smoothing - datamodule: iam_extended_paragraphs - hydra: default diff --git a/training/conf/datamodule/iam_extended_paragraphs.yaml b/training/conf/datamodule/iam_extended_paragraphs.yaml index a0ffe56..6fa8206 100644 --- a/training/conf/datamodule/iam_extended_paragraphs.yaml +++ b/training/conf/datamodule/iam_extended_paragraphs.yaml @@ -2,7 +2,6 @@ _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs batch_size: 4 num_workers: 12 train_fraction: 0.8 -augment: true pin_memory: false -word_pieces: true -resize: null +transform: transform/paragraphs.yaml +test_transform: transform/paragraphs.yaml diff --git a/training/conf/datamodule/iam_lines.yaml b/training/conf/datamodule/iam_lines.yaml index d3c2af8..36e7093 100644 --- a/training/conf/datamodule/iam_lines.yaml +++ b/training/conf/datamodule/iam_lines.yaml @@ -2,6 +2,6 @@ _target_: text_recognizer.data.iam_lines.IAMLines batch_size: 8 num_workers: 12 train_fraction: 0.8 -augment: true pin_memory: false -# word_pieces: true +transform: transform/iam_lines.yaml +test_transform: test_transform/iam_lines.yaml diff --git a/training/conf/experiment/cnn_htr_char_lines.yaml b/training/conf/experiment/cnn_htr_char_lines.yaml index 0d62a73..53f6d91 100644 --- a/training/conf/experiment/cnn_htr_char_lines.yaml +++ b/training/conf/experiment/cnn_htr_char_lines.yaml @@ -1,5 +1,3 @@ -# @package _global_ - defaults: - override /mapping: null - override /criterion: null @@ -19,11 +17,10 @@ criterion: _target_: text_recognizer.criterions.label_smoothing.LabelSmoothingLoss smoothing: 0.1 ignore_index: *ignore_index - # _target_: torch.nn.CrossEntropyLoss - # ignore_index: *ignore_index mapping: - _target_: text_recognizer.data.emnist_mapping.EmnistMapping + mapping: &mapping + _target_: text_recognizer.data.emnist_mapping.EmnistMapping callbacks: stochastic_weight_averaging: @@ -73,6 +70,7 @@ datamodule: augment: true pin_memory: true word_pieces: false + <<: *mapping network: _target_: text_recognizer.networks.conv_transformer.ConvTransformer @@ -80,6 +78,7 @@ network: hidden_dim: &hidden_dim 128 encoder_dim: 1280 dropout_rate: 0.2 + <<: *mapping num_classes: *num_classes pad_index: *ignore_index encoder: diff --git a/training/conf/experiment/cnn_htr_wp_lines.yaml b/training/conf/experiment/cnn_htr_wp_lines.yaml index 79075cd..f467b74 100644 --- a/training/conf/experiment/cnn_htr_wp_lines.yaml +++ b/training/conf/experiment/cnn_htr_wp_lines.yaml @@ -1,5 +1,3 @@ -# @package _global_ - defaults: - override /mapping: null - override /criterion: null diff --git a/training/conf/experiment/cnn_transformer_paragraphs.yaml b/training/conf/experiment/cnn_transformer_paragraphs.yaml index 8feb1bc..910d408 100644 --- a/training/conf/experiment/cnn_transformer_paragraphs.yaml +++ b/training/conf/experiment/cnn_transformer_paragraphs.yaml @@ -1,5 +1,3 @@ -# @package _global_ - defaults: - override /mapping: null - override /criterion: null diff --git a/training/conf/experiment/cnn_transformer_paragraphs_wp.yaml b/training/conf/experiment/cnn_transformer_paragraphs_wp.yaml index 1c9bba1..499a609 100644 --- a/training/conf/experiment/cnn_transformer_paragraphs_wp.yaml +++ b/training/conf/experiment/cnn_transformer_paragraphs_wp.yaml @@ -1,5 +1,3 @@ -# @package _global_ - defaults: - override /mapping: null - override /criterion: null diff --git a/training/conf/experiment/vqgan.yaml b/training/conf/experiment/vqgan.yaml index 572c320..98f3346 100644 --- a/training/conf/experiment/vqgan.yaml +++ b/training/conf/experiment/vqgan.yaml @@ -1,5 +1,3 @@ -# @package _global_ - defaults: - override /network: vqvae - override /criterion: null diff --git a/training/conf/experiment/vqgan_htr_char.yaml b/training/conf/experiment/vqgan_htr_char.yaml index 426524f..af3fa40 100644 --- a/training/conf/experiment/vqgan_htr_char.yaml +++ b/training/conf/experiment/vqgan_htr_char.yaml @@ -1,5 +1,3 @@ -# @package _global_ - defaults: - override /mapping: null - override /network: null diff --git a/training/conf/experiment/vqgan_htr_char_iam_lines.yaml b/training/conf/experiment/vqgan_htr_char_iam_lines.yaml index 9f4791f..27fdfda 100644 --- a/training/conf/experiment/vqgan_htr_char_iam_lines.yaml +++ b/training/conf/experiment/vqgan_htr_char_iam_lines.yaml @@ -1,5 +1,3 @@ -# @package _global_ - defaults: - override /mapping: null - override /criterion: null diff --git a/training/conf/experiment/vqgan_iam_lines.yaml b/training/conf/experiment/vqgan_iam_lines.yaml index 8bdf415..890948c 100644 --- a/training/conf/experiment/vqgan_iam_lines.yaml +++ b/training/conf/experiment/vqgan_iam_lines.yaml @@ -1,5 +1,3 @@ -# @package _global_ - defaults: - override /network: null - override /criterion: null diff --git a/training/conf/experiment/vqvae.yaml b/training/conf/experiment/vqvae.yaml index 6e42690..d069aef 100644 --- a/training/conf/experiment/vqvae.yaml +++ b/training/conf/experiment/vqvae.yaml @@ -1,5 +1,3 @@ -# @package _global_ - defaults: - override /network: vqvae - override /criterion: mse diff --git a/training/conf/experiment/vqvae_pixelcnn.yaml b/training/conf/experiment/vqvae_pixelcnn.yaml deleted file mode 100644 index 4fae782..0000000 --- a/training/conf/experiment/vqvae_pixelcnn.yaml +++ /dev/null @@ -1,24 +0,0 @@ -# @package _global_ - -defaults: - - override /network: vqvae_pixelcnn - - override /criterion: mae - - override /model: lit_vqvae - - override /callbacks: wandb_vae - - override /lr_schedulers: - - cosine_annealing - -trainer: - max_epochs: 256 - # gradient_clip_val: 0.25 - -datamodule: - batch_size: 8 - -# lr_scheduler: - # epochs: 64 - # steps_per_epoch: 1245 - -# optimizer: - # lr: 1.0e-3 - diff --git a/training/conf/mapping/characters.yaml b/training/conf/mapping/characters.yaml index 14e966b..41a26a3 100644 --- a/training/conf/mapping/characters.yaml +++ b/training/conf/mapping/characters.yaml @@ -1,2 +1,2 @@ -_target_: text_recognizer.data.emnist_mapping.EmnistMapping +_target_: text_recognizer.data.mappings.emnist_mapping.EmnistMapping extra_symbols: [ "\n" ] diff --git a/training/conf/mapping/word_piece.yaml b/training/conf/mapping/word_piece.yaml index 48384f5..ca8dd9c 100644 --- a/training/conf/mapping/word_piece.yaml +++ b/training/conf/mapping/word_piece.yaml @@ -1,4 +1,4 @@ -_target_: text_recognizer.data.word_piece_mapping.WordPieceMapping +_target_: text_recognizer.data.mappings.word_piece_mapping.WordPieceMapping num_features: 1000 tokens: iamdb_1kwp_tokens_1000.txt lexicon: iamdb_1kwp_lex_1000.txt |