From 04c40f790e405ced6e6b90cf0a8aea268b9345c4 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 4 Aug 2021 15:15:26 +0200 Subject: Add char htr experiment, rename from ocr to htr, vqvae loss collapses --- training/callbacks/wandb_callbacks.py | 2 +- training/conf/callbacks/wandb_htr.yaml | 6 ++++++ training/conf/callbacks/wandb_htr_predictions.yaml | 3 +++ training/conf/callbacks/wandb_ocr.yaml | 6 ------ training/conf/callbacks/wandb_ocr_predictions.yaml | 3 --- training/conf/config.yaml | 2 +- training/conf/experiment/htr_char.yaml | 12 ++++++++++++ training/conf/experiment/vqvae.yaml | 3 ++- training/conf/mapping/characters.yaml | 2 ++ training/conf/mapping/emnist.yaml | 2 -- training/conf/network/decoder/transformer_decoder.yaml | 2 +- training/conf/trainer/default.yaml | 2 +- 12 files changed, 29 insertions(+), 16 deletions(-) create mode 100644 training/conf/callbacks/wandb_htr.yaml create mode 100644 training/conf/callbacks/wandb_htr_predictions.yaml delete mode 100644 training/conf/callbacks/wandb_ocr.yaml delete mode 100644 training/conf/callbacks/wandb_ocr_predictions.yaml create mode 100644 training/conf/experiment/htr_char.yaml create mode 100644 training/conf/mapping/characters.yaml delete mode 100644 training/conf/mapping/emnist.yaml (limited to 'training') diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py index c750e4b..61d71df 100644 --- a/training/callbacks/wandb_callbacks.py +++ b/training/callbacks/wandb_callbacks.py @@ -119,7 +119,7 @@ class LogTextPredictions(Callback): ] experiment.log( - {f"OCR/{experiment.name}/{stage}": wandb.Table(data=data, columns=columns)} + {f"HTR/{experiment.name}/{stage}": wandb.Table(data=data, columns=columns)} ) def on_sanity_check_start( diff --git a/training/conf/callbacks/wandb_htr.yaml b/training/conf/callbacks/wandb_htr.yaml new file mode 100644 index 0000000..9c9a6da --- /dev/null +++ b/training/conf/callbacks/wandb_htr.yaml @@ -0,0 +1,6 @@ +defaults: + - default + - wandb_watch + - wandb_code + - wandb_checkpoints + - wandb_ocr_predictions diff --git a/training/conf/callbacks/wandb_htr_predictions.yaml b/training/conf/callbacks/wandb_htr_predictions.yaml new file mode 100644 index 0000000..573fa96 --- /dev/null +++ b/training/conf/callbacks/wandb_htr_predictions.yaml @@ -0,0 +1,3 @@ +log_text_predictions: + _target_: callbacks.wandb_callbacks.LogTextPredictions + num_samples: 8 diff --git a/training/conf/callbacks/wandb_ocr.yaml b/training/conf/callbacks/wandb_ocr.yaml deleted file mode 100644 index 9c9a6da..0000000 --- a/training/conf/callbacks/wandb_ocr.yaml +++ /dev/null @@ -1,6 +0,0 @@ -defaults: - - default - - wandb_watch - - wandb_code - - wandb_checkpoints - - wandb_ocr_predictions diff --git a/training/conf/callbacks/wandb_ocr_predictions.yaml b/training/conf/callbacks/wandb_ocr_predictions.yaml deleted file mode 100644 index 573fa96..0000000 --- a/training/conf/callbacks/wandb_ocr_predictions.yaml +++ /dev/null @@ -1,3 +0,0 @@ -log_text_predictions: - _target_: callbacks.wandb_callbacks.LogTextPredictions - num_samples: 8 diff --git a/training/conf/config.yaml b/training/conf/config.yaml index 6b74502..c606366 100644 --- a/training/conf/config.yaml +++ b/training/conf/config.yaml @@ -1,7 +1,7 @@ # @package _global_ defaults: - - callbacks: wandb_ocr + - callbacks: wandb_htr - criterion: label_smoothing - datamodule: iam_extended_paragraphs - hydra: default diff --git a/training/conf/experiment/htr_char.yaml b/training/conf/experiment/htr_char.yaml new file mode 100644 index 0000000..77126ae --- /dev/null +++ b/training/conf/experiment/htr_char.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - override /mapping: characters + +criterion: + ignore_index: 3 + +network: + num_classes: 89 + pad_index: 3 + max_output_len: 682 diff --git a/training/conf/experiment/vqvae.yaml b/training/conf/experiment/vqvae.yaml index 13e5f34..699612e 100644 --- a/training/conf/experiment/vqvae.yaml +++ b/training/conf/experiment/vqvae.yaml @@ -8,6 +8,7 @@ defaults: trainer: max_epochs: 64 + gradient_clip_val: 0.25 datamodule: batch_size: 32 @@ -17,4 +18,4 @@ lr_scheduler: steps_per_epoch: 624 optimizer: - lr: 1.0e-2 + lr: 1.0e-3 diff --git a/training/conf/mapping/characters.yaml b/training/conf/mapping/characters.yaml new file mode 100644 index 0000000..14e966b --- /dev/null +++ b/training/conf/mapping/characters.yaml @@ -0,0 +1,2 @@ +_target_: text_recognizer.data.emnist_mapping.EmnistMapping +extra_symbols: [ "\n" ] diff --git a/training/conf/mapping/emnist.yaml b/training/conf/mapping/emnist.yaml deleted file mode 100644 index 14e966b..0000000 --- a/training/conf/mapping/emnist.yaml +++ /dev/null @@ -1,2 +0,0 @@ -_target_: text_recognizer.data.emnist_mapping.EmnistMapping -extra_symbols: [ "\n" ] diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml index c326c04..bc0678b 100644 --- a/training/conf/network/decoder/transformer_decoder.yaml +++ b/training/conf/network/decoder/transformer_decoder.yaml @@ -4,7 +4,7 @@ defaults: _target_: text_recognizer.networks.transformer.Decoder dim: 128 depth: 2 -num_heads: 8 +num_heads: 4 attn_fn: text_recognizer.networks.transformer.attention.Attention attn_kwargs: dim_head: 64 diff --git a/training/conf/trainer/default.yaml b/training/conf/trainer/default.yaml index c665adc..0fa9ce1 100644 --- a/training/conf/trainer/default.yaml +++ b/training/conf/trainer/default.yaml @@ -8,7 +8,7 @@ gpus: 1 precision: 16 max_epochs: 512 terminate_on_nan: true -weights_summary: top +weights_summary: full limit_train_batches: 1.0 limit_val_batches: 1.0 limit_test_batches: 1.0 -- cgit v1.2.3-70-g09d2