diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/callbacks/wandb_callbacks.py | 2 |
-rw-r--r-- | training/conf/callbacks/wandb_htr.yaml (renamed from training/conf/callbacks/wandb_ocr.yaml) | 0 |
-rw-r--r-- | training/conf/callbacks/wandb_htr_predictions.yaml (renamed from training/conf/callbacks/wandb_ocr_predictions.yaml) | 0 |
-rw-r--r-- | training/conf/config.yaml | 2 |
-rw-r--r-- | training/conf/experiment/htr_char.yaml | 12 |
-rw-r--r-- | training/conf/experiment/vqvae.yaml | 3 |
-rw-r--r-- | training/conf/mapping/characters.yaml (renamed from training/conf/mapping/emnist.yaml) | 0 |
-rw-r--r-- | training/conf/network/decoder/transformer_decoder.yaml | 2 |
-rw-r--r-- | training/conf/trainer/default.yaml | 2 |
9 files changed, 18 insertions, 5 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py
index c750e4b..61d71df 100644
--- a/training/callbacks/wandb_callbacks.py
+++ b/training/callbacks/wandb_callbacks.py
@@ -119,7 +119,7 @@ class LogTextPredictions(Callback):
         ]

         experiment.log(
-            {f"OCR/{experiment.name}/{stage}": wandb.Table(data=data, columns=columns)}
+            {f"HTR/{experiment.name}/{stage}": wandb.Table(data=data, columns=columns)}
         )

     def on_sanity_check_start(
diff --git a/training/conf/callbacks/wandb_ocr.yaml b/training/conf/callbacks/wandb_htr.yaml
index 9c9a6da..9c9a6da 100644
--- a/training/conf/callbacks/wandb_ocr.yaml
+++ b/training/conf/callbacks/wandb_htr.yaml
diff --git a/training/conf/callbacks/wandb_ocr_predictions.yaml b/training/conf/callbacks/wandb_htr_predictions.yaml
index 573fa96..573fa96 100644
--- a/training/conf/callbacks/wandb_ocr_predictions.yaml
+++ b/training/conf/callbacks/wandb_htr_predictions.yaml
diff --git a/training/conf/config.yaml b/training/conf/config.yaml
index 6b74502..c606366 100644
--- a/training/conf/config.yaml
+++ b/training/conf/config.yaml
@@ -1,7 +1,7 @@
 # @package _global_

 defaults:
-  - callbacks: wandb_ocr
+  - callbacks: wandb_htr
   - criterion: label_smoothing
   - datamodule: iam_extended_paragraphs
   - hydra: default
diff --git a/training/conf/experiment/htr_char.yaml b/training/conf/experiment/htr_char.yaml
new file mode 100644
index 0000000..77126ae
--- /dev/null
+++ b/training/conf/experiment/htr_char.yaml
@@ -0,0 +1,12 @@
+# @package _global_
+
+defaults:
+  - override /mapping: characters
+
+criterion:
+  ignore_index: 3
+
+network:
+  num_classes: 89
+  pad_index: 3
+  max_output_len: 682
diff --git a/training/conf/experiment/vqvae.yaml b/training/conf/experiment/vqvae.yaml
index 13e5f34..699612e 100644
--- a/training/conf/experiment/vqvae.yaml
+++ b/training/conf/experiment/vqvae.yaml
@@ -8,6 +8,7 @@ defaults:

 trainer:
   max_epochs: 64
+  gradient_clip_val: 0.25

 datamodule:
   batch_size: 32
@@ -17,4 +18,4 @@ lr_scheduler:
   steps_per_epoch: 624

 optimizer:
-  lr: 1.0e-2
+  lr: 1.0e-3
diff --git a/training/conf/mapping/emnist.yaml b/training/conf/mapping/characters.yaml
index 14e966b..14e966b 100644
--- a/training/conf/mapping/emnist.yaml
+++ b/training/conf/mapping/characters.yaml
diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml
index c326c04..bc0678b 100644
--- a/training/conf/network/decoder/transformer_decoder.yaml
+++ b/training/conf/network/decoder/transformer_decoder.yaml
@@ -4,7 +4,7 @@ defaults:
 _target_: text_recognizer.networks.transformer.Decoder
 dim: 128
 depth: 2
-num_heads: 8
+num_heads: 4
 attn_fn: text_recognizer.networks.transformer.attention.Attention
 attn_kwargs:
   dim_head: 64
diff --git a/training/conf/trainer/default.yaml b/training/conf/trainer/default.yaml
index c665adc..0fa9ce1 100644
--- a/training/conf/trainer/default.yaml
+++ b/training/conf/trainer/default.yaml
@@ -8,7 +8,7 @@ gpus: 1
 precision: 16
 max_epochs: 512
 terminate_on_nan: true
-weights_summary: top
+weights_summary: full
 limit_train_batches: 1.0
 limit_val_batches: 1.0
 limit_test_batches: 1.0