From bd4bd443f339e95007bfdabf3e060db720f4d4b9 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 3 Aug 2021 18:18:48 +0200 Subject: Training working, multiple bug fixes --- training/conf/callbacks/checkpoint.yaml | 2 +- training/conf/callbacks/wandb/checkpoints.yaml | 4 ---- training/conf/callbacks/wandb/code.yaml | 3 --- .../conf/callbacks/wandb/image_reconstructions.yaml | 0 training/conf/callbacks/wandb/ocr_predictions.yaml | 3 --- training/conf/callbacks/wandb/watch.yaml | 4 ---- training/conf/callbacks/wandb_checkpoints.yaml | 4 ++++ training/conf/callbacks/wandb_code.yaml | 3 +++ .../conf/callbacks/wandb_image_reconstructions.yaml | 0 training/conf/callbacks/wandb_ocr.yaml | 8 ++++---- training/conf/callbacks/wandb_ocr_predictions.yaml | 3 +++ training/conf/callbacks/wandb_watch.yaml | 4 ++++ training/conf/config.yaml | 21 ++++++++++++++++++++- training/conf/criterion/label_smoothing.yaml | 5 ++--- .../conf/datamodule/iam_extended_paragraphs.yaml | 3 ++- training/conf/lr_scheduler/one_cycle.yaml | 4 ++-- training/conf/mapping/word_piece.yaml | 4 ++-- training/conf/model/lit_transformer.yaml | 2 +- training/conf/network/conv_transformer.yaml | 1 - .../conf/network/decoder/transformer_decoder.yaml | 1 + 20 files changed, 49 insertions(+), 30 deletions(-) delete mode 100644 training/conf/callbacks/wandb/checkpoints.yaml delete mode 100644 training/conf/callbacks/wandb/code.yaml delete mode 100644 training/conf/callbacks/wandb/image_reconstructions.yaml delete mode 100644 training/conf/callbacks/wandb/ocr_predictions.yaml delete mode 100644 training/conf/callbacks/wandb/watch.yaml create mode 100644 training/conf/callbacks/wandb_checkpoints.yaml create mode 100644 training/conf/callbacks/wandb_code.yaml create mode 100644 training/conf/callbacks/wandb_image_reconstructions.yaml create mode 100644 training/conf/callbacks/wandb_ocr_predictions.yaml create mode 100644 training/conf/callbacks/wandb_watch.yaml (limited to 'training/conf') diff --git a/training/conf/callbacks/checkpoint.yaml b/training/conf/callbacks/checkpoint.yaml index db34cb1..b4101d8 100644 --- a/training/conf/callbacks/checkpoint.yaml +++ b/training/conf/callbacks/checkpoint.yaml @@ -6,4 +6,4 @@ model_checkpoint: mode: min # can be "max" or "min" verbose: false dirpath: checkpoints/ - filename: {epoch:02d} + filename: "{epoch:02d}" diff --git a/training/conf/callbacks/wandb/checkpoints.yaml b/training/conf/callbacks/wandb/checkpoints.yaml deleted file mode 100644 index a4a16ff..0000000 --- a/training/conf/callbacks/wandb/checkpoints.yaml +++ /dev/null @@ -1,4 +0,0 @@ -upload_ckpts_as_artifact: - _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact - ckpt_dir: checkpoints/ - upload_best_only: true diff --git a/training/conf/callbacks/wandb/code.yaml b/training/conf/callbacks/wandb/code.yaml deleted file mode 100644 index 35f6ea3..0000000 --- a/training/conf/callbacks/wandb/code.yaml +++ /dev/null @@ -1,3 +0,0 @@ -upload_code_as_artifact: - _target_: callbacks.wandb_callbacks.UploadCodeAsArtifact - project_dir: ${work_dir}/text_recognizer diff --git a/training/conf/callbacks/wandb/image_reconstructions.yaml b/training/conf/callbacks/wandb/image_reconstructions.yaml deleted file mode 100644 index e69de29..0000000 diff --git a/training/conf/callbacks/wandb/ocr_predictions.yaml b/training/conf/callbacks/wandb/ocr_predictions.yaml deleted file mode 100644 index 573fa96..0000000 --- a/training/conf/callbacks/wandb/ocr_predictions.yaml +++ /dev/null @@ -1,3 +0,0 @@ -log_text_predictions: - _target_: callbacks.wandb_callbacks.LogTextPredictions - num_samples: 8 diff --git a/training/conf/callbacks/wandb/watch.yaml b/training/conf/callbacks/wandb/watch.yaml deleted file mode 100644 index 511608c..0000000 --- a/training/conf/callbacks/wandb/watch.yaml +++ /dev/null @@ -1,4 +0,0 @@ -watch_model: - _target_: callbacks.wandb_callbacks.WatchModel - log: all - log_freq: 100 diff --git a/training/conf/callbacks/wandb_checkpoints.yaml b/training/conf/callbacks/wandb_checkpoints.yaml new file mode 100644 index 0000000..a4a16ff --- /dev/null +++ b/training/conf/callbacks/wandb_checkpoints.yaml @@ -0,0 +1,4 @@ +upload_ckpts_as_artifact: + _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact + ckpt_dir: checkpoints/ + upload_best_only: true diff --git a/training/conf/callbacks/wandb_code.yaml b/training/conf/callbacks/wandb_code.yaml new file mode 100644 index 0000000..35f6ea3 --- /dev/null +++ b/training/conf/callbacks/wandb_code.yaml @@ -0,0 +1,3 @@ +upload_code_as_artifact: + _target_: callbacks.wandb_callbacks.UploadCodeAsArtifact + project_dir: ${work_dir}/text_recognizer diff --git a/training/conf/callbacks/wandb_image_reconstructions.yaml b/training/conf/callbacks/wandb_image_reconstructions.yaml new file mode 100644 index 0000000..e69de29 diff --git a/training/conf/callbacks/wandb_ocr.yaml b/training/conf/callbacks/wandb_ocr.yaml index efa3dda..9c9a6da 100644 --- a/training/conf/callbacks/wandb_ocr.yaml +++ b/training/conf/callbacks/wandb_ocr.yaml @@ -1,6 +1,6 @@ defaults: - default - - wandb/watch - - wandb/code - - wandb/checkpoints - - wandb/ocr_predictions + - wandb_watch + - wandb_code + - wandb_checkpoints + - wandb_ocr_predictions diff --git a/training/conf/callbacks/wandb_ocr_predictions.yaml b/training/conf/callbacks/wandb_ocr_predictions.yaml new file mode 100644 index 0000000..573fa96 --- /dev/null +++ b/training/conf/callbacks/wandb_ocr_predictions.yaml @@ -0,0 +1,3 @@ +log_text_predictions: + _target_: callbacks.wandb_callbacks.LogTextPredictions + num_samples: 8 diff --git a/training/conf/callbacks/wandb_watch.yaml b/training/conf/callbacks/wandb_watch.yaml new file mode 100644 index 0000000..511608c --- /dev/null +++ b/training/conf/callbacks/wandb_watch.yaml @@ -0,0 +1,4 @@ +watch_model: + _target_: callbacks.wandb_callbacks.WatchModel + log: all + log_freq: 100 diff --git a/training/conf/config.yaml b/training/conf/config.yaml index 93215ed..782bcbb 100644 --- a/training/conf/config.yaml +++ b/training/conf/config.yaml @@ -1,8 +1,9 @@ defaults: - callbacks: wandb_ocr - criterion: label_smoothing - - dataset: iam_extended_paragraphs + - datamodule: iam_extended_paragraphs - hydra: default + - logger: wandb - lr_scheduler: one_cycle - mapping: word_piece - model: lit_transformer @@ -15,3 +16,21 @@ tune: false train: true test: true logging: INFO + +# path to original working directory +# hydra hijacks working directory by changing it to the current log directory, +# so it's useful to have this path as a special variable +# learn more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory +work_dir: ${hydra:runtime.cwd} + +# use `python run.py debug=true` for easy debugging! +# this will run 1 train, val and test loop with only 1 batch +# equivalent to running `python run.py trainer.fast_dev_run=true` +# (this is placed here just for easier access from command line) +debug: False + +# pretty print config at the start of the run using Rich library +print_config: True + +# disable python warnings if they annoy you +ignore_warnings: True diff --git a/training/conf/criterion/label_smoothing.yaml b/training/conf/criterion/label_smoothing.yaml index 13daba8..684b5bb 100644 --- a/training/conf/criterion/label_smoothing.yaml +++ b/training/conf/criterion/label_smoothing.yaml @@ -1,4 +1,3 @@ -_target_: text_recognizer.criterion.label_smoothing.LabelSmoothingLoss -label_smoothing: 0.1 -vocab_size: 1006 +_target_: text_recognizer.criterions.label_smoothing.LabelSmoothingLoss +smoothing: 0.1 ignore_index: 1002 diff --git a/training/conf/datamodule/iam_extended_paragraphs.yaml b/training/conf/datamodule/iam_extended_paragraphs.yaml index 3070b56..2d1a03e 100644 --- a/training/conf/datamodule/iam_extended_paragraphs.yaml +++ b/training/conf/datamodule/iam_extended_paragraphs.yaml @@ -1,5 +1,6 @@ _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs -batch_size: 32 +batch_size: 4 num_workers: 12 train_fraction: 0.8 augment: true +pin_memory: false diff --git a/training/conf/lr_scheduler/one_cycle.yaml b/training/conf/lr_scheduler/one_cycle.yaml index 5afdf81..eecee8a 100644 --- a/training/conf/lr_scheduler/one_cycle.yaml +++ b/training/conf/lr_scheduler/one_cycle.yaml @@ -1,8 +1,8 @@ _target_: torch.optim.lr_scheduler.OneCycleLR max_lr: 1.0e-3 total_steps: null -epochs: null -steps_per_epoch: null +epochs: 512 +steps_per_epoch: 4992 pct_start: 0.3 anneal_strategy: cos cycle_momentum: true diff --git a/training/conf/mapping/word_piece.yaml b/training/conf/mapping/word_piece.yaml index 3792523..48384f5 100644 --- a/training/conf/mapping/word_piece.yaml +++ b/training/conf/mapping/word_piece.yaml @@ -1,4 +1,4 @@ -_target_: text_recognizer.data.mappings.WordPieceMapping +_target_: text_recognizer.data.word_piece_mapping.WordPieceMapping num_features: 1000 tokens: iamdb_1kwp_tokens_1000.txt lexicon: iamdb_1kwp_lex_1000.txt @@ -6,4 +6,4 @@ data_dir: null use_words: false prepend_wordsep: false special_tokens: [ , ,

] -extra_symbols: [ \n ] +extra_symbols: [ "\n" ] diff --git a/training/conf/model/lit_transformer.yaml b/training/conf/model/lit_transformer.yaml index 6ffde4e..c190151 100644 --- a/training/conf/model/lit_transformer.yaml +++ b/training/conf/model/lit_transformer.yaml @@ -1,7 +1,7 @@ _target_: text_recognizer.models.transformer.TransformerLitModel interval: step monitor: val/loss -ignore_tokens: [ , ,

] +max_output_len: 451 start_token: end_token: pad_token:

diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml index a97157d..f76e892 100644 --- a/training/conf/network/conv_transformer.yaml +++ b/training/conf/network/conv_transformer.yaml @@ -6,6 +6,5 @@ _target_: text_recognizer.networks.conv_transformer.ConvTransformer input_dims: [1, 576, 640] hidden_dim: 96 dropout_rate: 0.2 -max_output_len: 451 num_classes: 1006 pad_index: 1002 diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml index 90b9d8a..eb80f64 100644 --- a/training/conf/network/decoder/transformer_decoder.yaml +++ b/training/conf/network/decoder/transformer_decoder.yaml @@ -18,3 +18,4 @@ ff_kwargs: dropout_rate: 0.2 cross_attend: true pre_norm: true +rotary_emb: null -- cgit v1.2.3-70-g09d2