summary | refs | log | tree | commit | diff
path: root/training
diff options
context:
space:
mode:
author Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-08-04 15:15:26 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-08-04 15:15:26 +0200
commit 04c40f790e405ced6e6b90cf0a8aea268b9345c4 (patch)
tree d5e05ee09fa99ee8d56d5373bde18626274a1fdd /training
parent d3afa310f77f47553586eeee58e3d3345a754e2c (diff)
Add char htr experiment, rename from ocr to htr, vqvae loss collapses
Diffstat (limited to 'training')
-rw-r--r-- training/callbacks/wandb_callbacks.py 2
-rw-r--r-- training/conf/callbacks/wandb_htr.yaml (renamed from training/conf/callbacks/wandb_ocr.yaml) 0
-rw-r--r-- training/conf/callbacks/wandb_htr_predictions.yaml (renamed from training/conf/callbacks/wandb_ocr_predictions.yaml) 0
-rw-r--r-- training/conf/config.yaml 2
-rw-r--r-- training/conf/experiment/htr_char.yaml 12
-rw-r--r-- training/conf/experiment/vqvae.yaml 3
-rw-r--r-- training/conf/mapping/characters.yaml (renamed from training/conf/mapping/emnist.yaml) 0
-rw-r--r-- training/conf/network/decoder/transformer_decoder.yaml 2
-rw-r--r-- training/conf/trainer/default.yaml 2
9 files changed, 18 insertions, 5 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py
index c750e4b..61d71df 100644
--- a/training/callbacks/wandb_callbacks.py
+++ b/training/callbacks/wandb_callbacks.py
@@ -119,7 +119,7 @@ class LogTextPredictions(Callback):
]
experiment.log(
- {f"OCR/{experiment.name}/{stage}": wandb.Table(data=data, columns=columns)}
+ {f"HTR/{experiment.name}/{stage}": wandb.Table(data=data, columns=columns)}
)
def on_sanity_check_start(
diff --git a/training/conf/callbacks/wandb_ocr.yaml b/training/conf/callbacks/wandb_htr.yaml
index 9c9a6da..9c9a6da 100644
--- a/training/conf/callbacks/wandb_ocr.yaml
+++ b/training/conf/callbacks/wandb_htr.yaml
diff --git a/training/conf/callbacks/wandb_ocr_predictions.yaml b/training/conf/callbacks/wandb_htr_predictions.yaml
index 573fa96..573fa96 100644
--- a/training/conf/callbacks/wandb_ocr_predictions.yaml
+++ b/training/conf/callbacks/wandb_htr_predictions.yaml
diff --git a/training/conf/config.yaml b/training/conf/config.yaml
index 6b74502..c606366 100644
--- a/training/conf/config.yaml
+++ b/training/conf/config.yaml
@@ -1,7 +1,7 @@
# @package _global_
defaults:
- - callbacks: wandb_ocr
+ - callbacks: wandb_htr
- criterion: label_smoothing
- datamodule: iam_extended_paragraphs
- hydra: default
diff --git a/training/conf/experiment/htr_char.yaml b/training/conf/experiment/htr_char.yaml
new file mode 100644
index 0000000..77126ae
--- /dev/null
+++ b/training/conf/experiment/htr_char.yaml
@@ -0,0 +1,12 @@
+# @package _global_
+
+defaults:
+ - override /mapping: characters
+
+criterion:
+ ignore_index: 3
+
+network:
+ num_classes: 89
+ pad_index: 3
+ max_output_len: 682
diff --git a/training/conf/experiment/vqvae.yaml b/training/conf/experiment/vqvae.yaml
index 13e5f34..699612e 100644
--- a/training/conf/experiment/vqvae.yaml
+++ b/training/conf/experiment/vqvae.yaml
@@ -8,6 +8,7 @@ defaults:
trainer:
max_epochs: 64
+ gradient_clip_val: 0.25
datamodule:
batch_size: 32
@@ -17,4 +18,4 @@ lr_scheduler:
steps_per_epoch: 624
optimizer:
- lr: 1.0e-2
+ lr: 1.0e-3
diff --git a/training/conf/mapping/emnist.yaml b/training/conf/mapping/characters.yaml
index 14e966b..14e966b 100644
--- a/training/conf/mapping/emnist.yaml
+++ b/training/conf/mapping/characters.yaml
diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml
index c326c04..bc0678b 100644
--- a/training/conf/network/decoder/transformer_decoder.yaml
+++ b/training/conf/network/decoder/transformer_decoder.yaml
@@ -4,7 +4,7 @@ defaults:
_target_: text_recognizer.networks.transformer.Decoder
dim: 128
depth: 2
-num_heads: 8
+num_heads: 4
attn_fn: text_recognizer.networks.transformer.attention.Attention
attn_kwargs:
dim_head: 64
diff --git a/training/conf/trainer/default.yaml b/training/conf/trainer/default.yaml
index c665adc..0fa9ce1 100644
--- a/training/conf/trainer/default.yaml
+++ b/training/conf/trainer/default.yaml
@@ -8,7 +8,7 @@ gpus: 1
precision: 16
max_epochs: 512
terminate_on_nan: true
-weights_summary: top
+weights_summary: full
limit_train_batches: 1.0
limit_val_batches: 1.0
limit_test_batches: 1.0