From 75801019981492eedf9280cb352eea3d8e99b65f Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 2 Aug 2021 21:13:48 +0200 Subject: Fix log import, fix mapping in datamodules, fix nn modules can be hashed --- training/conf/callbacks/wandb.yaml | 20 -------------------- training/conf/callbacks/wandb/checkpoints.yaml | 4 ++++ training/conf/callbacks/wandb/code.yaml | 3 +++ .../conf/callbacks/wandb/image_reconstructions.yaml | 0 training/conf/callbacks/wandb/ocr_predictions.yaml | 3 +++ training/conf/callbacks/wandb/watch.yaml | 4 ++++ training/conf/callbacks/wandb_ocr.yaml | 6 ++++++ training/conf/config.yaml | 18 ++++++++---------- training/conf/criterion/label_smoothing.yaml | 2 +- training/conf/hydra/default.yaml | 6 ++++++ training/conf/mapping/word_piece.yaml | 9 +++++++++ training/conf/model/lit_transformer.yaml | 5 +---- training/conf/model/mapping/word_piece.yaml | 9 --------- training/conf/network/conv_transformer.yaml | 2 +- .../conf/network/decoder/transformer_decoder.yaml | 4 ++-- training/conf/trainer/default.yaml | 6 +++++- 16 files changed, 53 insertions(+), 48 deletions(-) delete mode 100644 training/conf/callbacks/wandb.yaml create mode 100644 training/conf/callbacks/wandb/checkpoints.yaml create mode 100644 training/conf/callbacks/wandb/code.yaml create mode 100644 training/conf/callbacks/wandb/image_reconstructions.yaml create mode 100644 training/conf/callbacks/wandb/ocr_predictions.yaml create mode 100644 training/conf/callbacks/wandb/watch.yaml create mode 100644 training/conf/callbacks/wandb_ocr.yaml create mode 100644 training/conf/hydra/default.yaml create mode 100644 training/conf/mapping/word_piece.yaml delete mode 100644 training/conf/model/mapping/word_piece.yaml (limited to 'training/conf') diff --git a/training/conf/callbacks/wandb.yaml b/training/conf/callbacks/wandb.yaml deleted file mode 100644 index 0017e11..0000000 --- a/training/conf/callbacks/wandb.yaml +++ /dev/null @@ -1,20 +0,0 @@ -defaults: - - default.yaml - 
-watch_model: - _target_: callbacks.wandb_callbacks.WatchModel - log: all - log_freq: 100 - -upload_code_as_artifact: - _target_: callbacks.wandb_callbacks.UploadCodeAsArtifact - project_dir: ${work_dir}/text_recognizer - -upload_ckpts_as_artifact: - _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact - ckpt_dir: checkpoints/ - upload_best_only: true - -log_text_predictions: - _target_: callbacks.wandb_callbacks.LogTextPredictions - num_samples: 8 diff --git a/training/conf/callbacks/wandb/checkpoints.yaml b/training/conf/callbacks/wandb/checkpoints.yaml new file mode 100644 index 0000000..a4a16ff --- /dev/null +++ b/training/conf/callbacks/wandb/checkpoints.yaml @@ -0,0 +1,4 @@ +upload_ckpts_as_artifact: + _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact + ckpt_dir: checkpoints/ + upload_best_only: true diff --git a/training/conf/callbacks/wandb/code.yaml b/training/conf/callbacks/wandb/code.yaml new file mode 100644 index 0000000..35f6ea3 --- /dev/null +++ b/training/conf/callbacks/wandb/code.yaml @@ -0,0 +1,3 @@ +upload_code_as_artifact: + _target_: callbacks.wandb_callbacks.UploadCodeAsArtifact + project_dir: ${work_dir}/text_recognizer diff --git a/training/conf/callbacks/wandb/image_reconstructions.yaml b/training/conf/callbacks/wandb/image_reconstructions.yaml new file mode 100644 index 0000000..e69de29 diff --git a/training/conf/callbacks/wandb/ocr_predictions.yaml b/training/conf/callbacks/wandb/ocr_predictions.yaml new file mode 100644 index 0000000..573fa96 --- /dev/null +++ b/training/conf/callbacks/wandb/ocr_predictions.yaml @@ -0,0 +1,3 @@ +log_text_predictions: + _target_: callbacks.wandb_callbacks.LogTextPredictions + num_samples: 8 diff --git a/training/conf/callbacks/wandb/watch.yaml b/training/conf/callbacks/wandb/watch.yaml new file mode 100644 index 0000000..511608c --- /dev/null +++ b/training/conf/callbacks/wandb/watch.yaml @@ -0,0 +1,4 @@ +watch_model: + _target_: callbacks.wandb_callbacks.WatchModel + log: all + 
log_freq: 100 diff --git a/training/conf/callbacks/wandb_ocr.yaml b/training/conf/callbacks/wandb_ocr.yaml new file mode 100644 index 0000000..efa3dda --- /dev/null +++ b/training/conf/callbacks/wandb_ocr.yaml @@ -0,0 +1,6 @@ +defaults: + - default + - wandb/watch + - wandb/code + - wandb/checkpoints + - wandb/ocr_predictions diff --git a/training/conf/config.yaml b/training/conf/config.yaml index a8e718e..93215ed 100644 --- a/training/conf/config.yaml +++ b/training/conf/config.yaml @@ -1,19 +1,17 @@ defaults: - - network: vqvae - - criterion: mse - - optimizer: madgrad - - lr_scheduler: one_cycle - - model: lit_vqvae + - callbacks: wandb_ocr + - criterion: label_smoothing - dataset: iam_extended_paragraphs + - hydra: default + - lr_scheduler: one_cycle + - mapping: word_piece + - model: lit_transformer + - network: conv_transformer + - optimizer: madgrad - trainer: default - - callbacks: - - checkpoint - - learning_rate_monitor seed: 4711 -wandb: false tune: false train: true test: true -load_checkpoint: null logging: INFO diff --git a/training/conf/criterion/label_smoothing.yaml b/training/conf/criterion/label_smoothing.yaml index ee47c59..13daba8 100644 --- a/training/conf/criterion/label_smoothing.yaml +++ b/training/conf/criterion/label_smoothing.yaml @@ -1,4 +1,4 @@ -_target_: text_recognizer.criterion.label_smoothing +_target_: text_recognizer.criterion.label_smoothing.LabelSmoothingLoss label_smoothing: 0.1 vocab_size: 1006 ignore_index: 1002 diff --git a/training/conf/hydra/default.yaml b/training/conf/hydra/default.yaml new file mode 100644 index 0000000..dfd9721 --- /dev/null +++ b/training/conf/hydra/default.yaml @@ -0,0 +1,6 @@ +# output paths for hydra logs +run: + dir: logs/runs/${now:%Y-%m-%d}/${now:%H-%M-%S} +sweep: + dir: logs/multiruns/${now:%Y-%m-%d_%H-%M-%S} + subdir: ${hydra.job.num} diff --git a/training/conf/mapping/word_piece.yaml b/training/conf/mapping/word_piece.yaml new file mode 100644 index 0000000..3792523 --- /dev/null +++ 
b/training/conf/mapping/word_piece.yaml @@ -0,0 +1,9 @@ +_target_: text_recognizer.data.mappings.WordPieceMapping +num_features: 1000 +tokens: iamdb_1kwp_tokens_1000.txt +lexicon: iamdb_1kwp_lex_1000.txt +data_dir: null +use_words: false +prepend_wordsep: false +special_tokens: [ <s>, <e>, <p>
] +extra_symbols: [ \n ] diff --git a/training/conf/model/lit_transformer.yaml b/training/conf/model/lit_transformer.yaml index 5341d8e..6ffde4e 100644 --- a/training/conf/model/lit_transformer.yaml +++ b/training/conf/model/lit_transformer.yaml @@ -1,8 +1,5 @@ -defaults: - - mapping: word_piece - _target_: text_recognizer.models.transformer.TransformerLitModel -interval: null +interval: step monitor: val/loss ignore_tokens: [ <s>, <e>, <p>
] start_token: <s> diff --git a/training/conf/model/mapping/word_piece.yaml b/training/conf/model/mapping/word_piece.yaml deleted file mode 100644 index 39e2ba4..0000000 --- a/training/conf/model/mapping/word_piece.yaml +++ /dev/null @@ -1,9 +0,0 @@ -_target_: text_recognizer.data.mappings.WordPieceMapping -num_features: 1000 -tokens: iamdb_1kwp_tokens_1000.txt -lexicon: iamdb_1kwp_lex_1000.txt -data_dir: null -use_words: false -prepend_wordsep: false -special_tokens: ["<s>", "<e>", "<p>
"] -extra_symbols: ["\n"] diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml index 7d57a2d..a97157d 100644 --- a/training/conf/network/conv_transformer.yaml +++ b/training/conf/network/conv_transformer.yaml @@ -4,7 +4,7 @@ defaults: _target_: text_recognizer.networks.conv_transformer.ConvTransformer input_dims: [1, 576, 640] -hidden_dim: 256 +hidden_dim: 96 dropout_rate: 0.2 max_output_len: 451 num_classes: 1006 diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml index 3122de1..90b9d8a 100644 --- a/training/conf/network/decoder/transformer_decoder.yaml +++ b/training/conf/network/decoder/transformer_decoder.yaml @@ -2,12 +2,12 @@ defaults: - rotary_emb: null _target_: text_recognizer.networks.transformer.Decoder -dim: 256 +dim: 96 depth: 2 num_heads: 8 attn_fn: text_recognizer.networks.transformer.attention.Attention attn_kwargs: - dim_head: 64 + dim_head: 16 dropout_rate: 0.2 norm_fn: torch.nn.LayerNorm ff_fn: text_recognizer.networks.transformer.mlp.FeedForward diff --git a/training/conf/trainer/default.yaml b/training/conf/trainer/default.yaml index 5ed6552..c665adc 100644 --- a/training/conf/trainer/default.yaml +++ b/training/conf/trainer/default.yaml @@ -6,6 +6,10 @@ gradient_clip_val: 0 fast_dev_run: false gpus: 1 precision: 16 -max_epochs: 64 +max_epochs: 512 terminate_on_nan: true weights_summary: top +limit_train_batches: 1.0 +limit_val_batches: 1.0 +limit_test_batches: 1.0 +resume_from_checkpoint: null -- cgit v1.2.3-70-g09d2