author     Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-10-10 18:06:54 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-10-10 18:06:54 +0200
commit     b69254ce3135c112e29f7f1c986b7f0817da0c33 (patch)
tree       e596a719b445c0cbbd4108079206ec9d14de1437 /training
parent     cba94bdcab90f288dd1172607500ba2b28279736 (diff)
Update configs
Diffstat (limited to 'training')
-rw-r--r--  training/conf/callbacks/checkpoint.yaml  9
-rw-r--r--  training/conf/callbacks/default.yaml  4
-rw-r--r--  training/conf/callbacks/early_stopping.yaml  6
-rw-r--r--  training/conf/callbacks/learning_rate_monitor.yaml  4
-rw-r--r--  training/conf/callbacks/swa.yaml  7
-rw-r--r--  training/conf/callbacks/wandb_checkpoints.yaml  4
-rw-r--r--  training/conf/callbacks/wandb_config.yaml  2
-rw-r--r--  training/conf/callbacks/wandb_htr.yaml  6
-rw-r--r--  training/conf/callbacks/wandb_htr_predictions.yaml  4
-rw-r--r--  training/conf/callbacks/wandb_image_reconstructions.yaml  5
-rw-r--r--  training/conf/callbacks/wandb_vae.yaml  6
-rw-r--r--  training/conf/callbacks/wandb_watch.yaml  4
-rw-r--r--  training/conf/config.yaml  2
-rw-r--r--  training/conf/datamodule/iam_extended_paragraphs.yaml  5
-rw-r--r--  training/conf/datamodule/iam_lines.yaml  4
-rw-r--r--  training/conf/experiment/cnn_htr_char_lines.yaml  9
-rw-r--r--  training/conf/experiment/cnn_htr_wp_lines.yaml  2
-rw-r--r--  training/conf/experiment/cnn_transformer_paragraphs.yaml  2
-rw-r--r--  training/conf/experiment/cnn_transformer_paragraphs_wp.yaml  2
-rw-r--r--  training/conf/experiment/vqgan.yaml  2
-rw-r--r--  training/conf/experiment/vqgan_htr_char.yaml  2
-rw-r--r--  training/conf/experiment/vqgan_htr_char_iam_lines.yaml  2
-rw-r--r--  training/conf/experiment/vqgan_iam_lines.yaml  2
-rw-r--r--  training/conf/experiment/vqvae.yaml  2
-rw-r--r--  training/conf/experiment/vqvae_pixelcnn.yaml  24
-rw-r--r--  training/conf/mapping/characters.yaml  2
-rw-r--r--  training/conf/mapping/word_piece.yaml  2
27 files changed, 13 insertions, 112 deletions
diff --git a/training/conf/callbacks/checkpoint.yaml b/training/conf/callbacks/checkpoint.yaml
deleted file mode 100644
index b4101d8..0000000
--- a/training/conf/callbacks/checkpoint.yaml
+++ /dev/null
@@ -1,9 +0,0 @@
-model_checkpoint:
- _target_: pytorch_lightning.callbacks.ModelCheckpoint
- monitor: val/loss # name of the logged metric which determines when model is improving
- save_top_k: 1 # save k best models (determined by above metric)
- save_last: true # additionaly always save model from last epoch
- mode: min # can be "max" or "min"
- verbose: false
- dirpath: checkpoints/
- filename: "{epoch:02d}"
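
For reference, the deleted checkpoint.yaml resolved (via Hydra instantiation) to a plain Lightning ModelCheckpoint. A minimal sketch of the equivalent direct construction, assuming pytorch_lightning is installed:

from pytorch_lightning.callbacks import ModelCheckpoint

# Equivalent of the removed config: keep the single best checkpoint by val/loss,
# always save the last epoch as well, and name checkpoint files by epoch number.
checkpoint_cb = ModelCheckpoint(
    monitor="val/loss",
    save_top_k=1,
    save_last=True,
    mode="min",
    verbose=False,
    dirpath="checkpoints/",
    filename="{epoch:02d}",
)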
diff --git a/training/conf/callbacks/default.yaml b/training/conf/callbacks/default.yaml
index 658fc03..c184039 100644
--- a/training/conf/callbacks/default.yaml
+++ b/training/conf/callbacks/default.yaml
@@ -1,3 +1,3 @@
 defaults:
-  - checkpoint
-  - learning_rate_monitor
+  - lightning: checkpoint
+  - lightning: learning_rate_monitor
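
default.yaml now pulls its entries from a nested `lightning` config group instead of flat files in the callbacks group. How the composed callback configs become objects is not shown in this diff; a hedged sketch of the usual Hydra pattern (instantiate every node that declares a `_target_`) might look like this:

from hydra.utils import instantiate
from omegaconf import DictConfig

def build_callbacks(callbacks_cfg: DictConfig) -> list:
    """Instantiate every callback entry that carries a _target_ key."""
    callbacks = []
    for cb_cfg in callbacks_cfg.values():
        if isinstance(cb_cfg, DictConfig) and "_target_" in cb_cfg:
            callbacks.append(instantiate(cb_cfg))
    return callbacks

# callbacks = build_callbacks(cfg.callbacks)  # then passed to pl.Trainer(callbacks=...)

The exact nesting under cfg.callbacks depends on how the new lightning/ subgroup is packaged, so the loop above is a sketch rather than the project's actual wiring.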
diff --git a/training/conf/callbacks/early_stopping.yaml b/training/conf/callbacks/early_stopping.yaml
deleted file mode 100644
index a188df3..0000000
--- a/training/conf/callbacks/early_stopping.yaml
+++ /dev/null
@@ -1,6 +0,0 @@
-early_stopping:
- _target_: pytorch_lightning.callbacks.EarlyStopping
- monitor: val/loss # name of the logged metric which determines when model is improving
- patience: 16 # how many epochs of not improving until training stops
- mode: min # can be "max" or "min"
- min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement
diff --git a/training/conf/callbacks/learning_rate_monitor.yaml b/training/conf/callbacks/learning_rate_monitor.yaml
deleted file mode 100644
index 4a14e1f..0000000
--- a/training/conf/callbacks/learning_rate_monitor.yaml
+++ /dev/null
@@ -1,4 +0,0 @@
-learning_rate_monitor:
- _target_: pytorch_lightning.callbacks.LearningRateMonitor
- logging_interval: step
- log_momentum: false
diff --git a/training/conf/callbacks/swa.yaml b/training/conf/callbacks/swa.yaml
deleted file mode 100644
index 73f8c66..0000000
--- a/training/conf/callbacks/swa.yaml
+++ /dev/null
@@ -1,7 +0,0 @@
-stochastic_weight_averaging:
- _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
- swa_epoch_start: 0.8
- swa_lrs: 0.05
- annealing_epochs: 10
- annealing_strategy: cos
- device: null
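
The removed SWA config mapped one-to-one onto Lightning's built-in callback; swa_epoch_start: 0.8 means averaging starts at 80% of max_epochs. A minimal equivalent, assuming a pytorch_lightning version that ships StochasticWeightAveraging:

from pytorch_lightning.callbacks import StochasticWeightAveraging

# Start weight averaging at 80% of training and anneal the learning rate
# towards the SWA value 0.05 over 10 epochs with a cosine schedule.
swa_cb = StochasticWeightAveraging(
    swa_epoch_start=0.8,
    swa_lrs=0.05,
    annealing_epochs=10,
    annealing_strategy="cos",
    device=None,
)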
diff --git a/training/conf/callbacks/wandb_checkpoints.yaml b/training/conf/callbacks/wandb_checkpoints.yaml
deleted file mode 100644
index a4a16ff..0000000
--- a/training/conf/callbacks/wandb_checkpoints.yaml
+++ /dev/null
@@ -1,4 +0,0 @@
-upload_ckpts_as_artifact:
- _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact
- ckpt_dir: checkpoints/
- upload_best_only: true
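
UploadCheckpointsAsArtifact lives in the project's callbacks.wandb_callbacks module and its implementation is not part of this diff. Purely as an illustration of the idea (an assumed sketch, not the project's actual code), such a callback could look like this:

import wandb
from pytorch_lightning import Callback

class UploadCheckpointsAsArtifactSketch(Callback):
    """Hypothetical sketch: push checkpoint files to W&B when training ends."""

    def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = True) -> None:
        self.ckpt_dir = ckpt_dir
        self.upload_best_only = upload_best_only

    def on_train_end(self, trainer, pl_module) -> None:
        artifact = wandb.Artifact("experiment-ckpts", type="checkpoints")
        if self.upload_best_only and trainer.checkpoint_callback is not None:
            artifact.add_file(trainer.checkpoint_callback.best_model_path)
        else:
            artifact.add_dir(self.ckpt_dir)
        wandb.log_artifact(artifact)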
diff --git a/training/conf/callbacks/wandb_config.yaml b/training/conf/callbacks/wandb_config.yaml
deleted file mode 100644
index 747a7c6..0000000
--- a/training/conf/callbacks/wandb_config.yaml
+++ /dev/null
@@ -1,2 +0,0 @@
-upload_code_as_artifact:
- _target_: callbacks.wandb_callbacks.UploadConfigAsArtifact
diff --git a/training/conf/callbacks/wandb_htr.yaml b/training/conf/callbacks/wandb_htr.yaml
deleted file mode 100644
index f8c1ef7..0000000
--- a/training/conf/callbacks/wandb_htr.yaml
+++ /dev/null
@@ -1,6 +0,0 @@
-defaults:
- - default
- - wandb_watch
- - wandb_config
- - wandb_checkpoints
- - wandb_htr_predictions
diff --git a/training/conf/callbacks/wandb_htr_predictions.yaml b/training/conf/callbacks/wandb_htr_predictions.yaml
deleted file mode 100644
index 468b6e0..0000000
--- a/training/conf/callbacks/wandb_htr_predictions.yaml
+++ /dev/null
@@ -1,4 +0,0 @@
-log_text_predictions:
- _target_: callbacks.wandb_callbacks.LogTextPredictions
- num_samples: 8
- log_train: false
diff --git a/training/conf/callbacks/wandb_image_reconstructions.yaml b/training/conf/callbacks/wandb_image_reconstructions.yaml
deleted file mode 100644
index fabfe31..0000000
--- a/training/conf/callbacks/wandb_image_reconstructions.yaml
+++ /dev/null
@@ -1,5 +0,0 @@
-log_image_reconstruction:
- _target_: callbacks.wandb_callbacks.LogReconstuctedImages
- num_samples: 8
- log_train: true
- use_sigmoid: true
diff --git a/training/conf/callbacks/wandb_vae.yaml b/training/conf/callbacks/wandb_vae.yaml
deleted file mode 100644
index ffc467f..0000000
--- a/training/conf/callbacks/wandb_vae.yaml
+++ /dev/null
@@ -1,6 +0,0 @@
-defaults:
- - default
- - wandb_watch
- - wandb_checkpoints
- - wandb_image_reconstructions
- - wandb_config
diff --git a/training/conf/callbacks/wandb_watch.yaml b/training/conf/callbacks/wandb_watch.yaml
deleted file mode 100644
index 511608c..0000000
--- a/training/conf/callbacks/wandb_watch.yaml
+++ /dev/null
@@ -1,4 +0,0 @@
-watch_model:
- _target_: callbacks.wandb_callbacks.WatchModel
- log: all
- log_freq: 100
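
Likewise, WatchModel wraps wandb.watch to log gradients and parameters. A hedged sketch of the pattern (not the project's actual implementation):

import wandb
from pytorch_lightning import Callback

class WatchModelSketch(Callback):
    """Hypothetical sketch: register the model with wandb.watch at train start."""

    def __init__(self, log: str = "all", log_freq: int = 100) -> None:
        self.log = log
        self.log_freq = log_freq

    def on_train_start(self, trainer, pl_module) -> None:
        wandb.watch(pl_module, log=self.log, log_freq=self.log_freq)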
diff --git a/training/conf/config.yaml b/training/conf/config.yaml
index 9ed366f..a58891b 100644
--- a/training/conf/config.yaml
+++ b/training/conf/config.yaml
@@ -1,7 +1,7 @@
 # @package _global_
 defaults:
-  - callbacks: wandb_htr
+  - callbacks: htr
   - criterion: label_smoothing
   - datamodule: iam_extended_paragraphs
   - hydra: default
diff --git a/training/conf/datamodule/iam_extended_paragraphs.yaml b/training/conf/datamodule/iam_extended_paragraphs.yaml
index a0ffe56..6fa8206 100644
--- a/training/conf/datamodule/iam_extended_paragraphs.yaml
+++ b/training/conf/datamodule/iam_extended_paragraphs.yaml
@@ -2,7 +2,6 @@ _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs
 batch_size: 4
 num_workers: 12
 train_fraction: 0.8
-augment: true
 pin_memory: false
-word_pieces: true
-resize: null
+transform: transform/paragraphs.yaml
+test_transform: transform/paragraphs.yaml
diff --git a/training/conf/datamodule/iam_lines.yaml b/training/conf/datamodule/iam_lines.yaml
index d3c2af8..36e7093 100644
--- a/training/conf/datamodule/iam_lines.yaml
+++ b/training/conf/datamodule/iam_lines.yaml
@@ -2,6 +2,6 @@ _target_: text_recognizer.data.iam_lines.IAMLines
 batch_size: 8
 num_workers: 12
 train_fraction: 0.8
-augment: true
 pin_memory: false
-# word_pieces: true
+transform: transform/iam_lines.yaml
+test_transform: test_transform/iam_lines.yaml
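
Both datamodules trade the old augment/word_pieces/resize flags for transform and test_transform entries that point at transform config files. How those paths are consumed is not shown here; one plausible (assumed) pattern is for the datamodule to load the referenced YAML and instantiate it, provided the transform configs declare Hydra _target_ entries:

from pathlib import Path
from hydra.utils import instantiate
from omegaconf import OmegaConf

def load_transform(transform_path: str, conf_root: str = "training/conf"):
    """Assumed helper: resolve a transform config path and build the transform."""
    cfg = OmegaConf.load(Path(conf_root) / transform_path)
    return instantiate(cfg)

# e.g. transform = load_transform("transform/iam_lines.yaml")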
diff --git a/training/conf/experiment/cnn_htr_char_lines.yaml b/training/conf/experiment/cnn_htr_char_lines.yaml
index 0d62a73..53f6d91 100644
--- a/training/conf/experiment/cnn_htr_char_lines.yaml
+++ b/training/conf/experiment/cnn_htr_char_lines.yaml
@@ -1,5 +1,3 @@
-# @package _global_
-
 defaults:
   - override /mapping: null
   - override /criterion: null
@@ -19,11 +17,10 @@ criterion:
   _target_: text_recognizer.criterions.label_smoothing.LabelSmoothingLoss
   smoothing: 0.1
   ignore_index: *ignore_index
-  # _target_: torch.nn.CrossEntropyLoss
-  # ignore_index: *ignore_index

 mapping:
-  _target_: text_recognizer.data.emnist_mapping.EmnistMapping
+  mapping: &mapping
+    _target_: text_recognizer.data.emnist_mapping.EmnistMapping

 callbacks:
   stochastic_weight_averaging:
@@ -73,6 +70,7 @@ datamodule:
   augment: true
   pin_memory: true
   word_pieces: false
+  <<: *mapping

 network:
   _target_: text_recognizer.networks.conv_transformer.ConvTransformer
@@ -80,6 +78,7 @@ network:
   hidden_dim: &hidden_dim 128
   encoder_dim: 1280
   dropout_rate: 0.2
+  <<: *mapping
   num_classes: *num_classes
   pad_index: *ignore_index
   encoder:
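
The experiment file above shares the mapping node between sections with a YAML anchor (&mapping) and merge keys (<<: *mapping). A small standalone check with PyYAML, which supports YAML 1.1 merge keys, shows the mechanics; whether Hydra/OmegaConf resolves merge keys identically is worth verifying in the project itself.

import yaml

doc = """
mapping:
  mapping: &mapping
    _target_: text_recognizer.data.emnist_mapping.EmnistMapping

datamodule:
  word_pieces: false
  <<: *mapping
"""

cfg = yaml.safe_load(doc)
# The merge key copies the anchored keys into `datamodule`:
print(sorted(cfg["datamodule"]))      # ['_target_', 'word_pieces']
print(cfg["datamodule"]["_target_"])  # text_recognizer.data.emnist_mapping.EmnistMapping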
diff --git a/training/conf/experiment/cnn_htr_wp_lines.yaml b/training/conf/experiment/cnn_htr_wp_lines.yaml
index 79075cd..f467b74 100644
--- a/training/conf/experiment/cnn_htr_wp_lines.yaml
+++ b/training/conf/experiment/cnn_htr_wp_lines.yaml
@@ -1,5 +1,3 @@
-# @package _global_
-
 defaults:
   - override /mapping: null
   - override /criterion: null
diff --git a/training/conf/experiment/cnn_transformer_paragraphs.yaml b/training/conf/experiment/cnn_transformer_paragraphs.yaml
index 8feb1bc..910d408 100644
--- a/training/conf/experiment/cnn_transformer_paragraphs.yaml
+++ b/training/conf/experiment/cnn_transformer_paragraphs.yaml
@@ -1,5 +1,3 @@
-# @package _global_
-
 defaults:
   - override /mapping: null
   - override /criterion: null
diff --git a/training/conf/experiment/cnn_transformer_paragraphs_wp.yaml b/training/conf/experiment/cnn_transformer_paragraphs_wp.yaml
index 1c9bba1..499a609 100644
--- a/training/conf/experiment/cnn_transformer_paragraphs_wp.yaml
+++ b/training/conf/experiment/cnn_transformer_paragraphs_wp.yaml
@@ -1,5 +1,3 @@
-# @package _global_
-
 defaults:
   - override /mapping: null
   - override /criterion: null
diff --git a/training/conf/experiment/vqgan.yaml b/training/conf/experiment/vqgan.yaml
index 572c320..98f3346 100644
--- a/training/conf/experiment/vqgan.yaml
+++ b/training/conf/experiment/vqgan.yaml
@@ -1,5 +1,3 @@
-# @package _global_
-
 defaults:
   - override /network: vqvae
   - override /criterion: null
diff --git a/training/conf/experiment/vqgan_htr_char.yaml b/training/conf/experiment/vqgan_htr_char.yaml
index 426524f..af3fa40 100644
--- a/training/conf/experiment/vqgan_htr_char.yaml
+++ b/training/conf/experiment/vqgan_htr_char.yaml
@@ -1,5 +1,3 @@
-# @package _global_
-
 defaults:
   - override /mapping: null
   - override /network: null
diff --git a/training/conf/experiment/vqgan_htr_char_iam_lines.yaml b/training/conf/experiment/vqgan_htr_char_iam_lines.yaml
index 9f4791f..27fdfda 100644
--- a/training/conf/experiment/vqgan_htr_char_iam_lines.yaml
+++ b/training/conf/experiment/vqgan_htr_char_iam_lines.yaml
@@ -1,5 +1,3 @@
-# @package _global_
-
 defaults:
   - override /mapping: null
   - override /criterion: null
diff --git a/training/conf/experiment/vqgan_iam_lines.yaml b/training/conf/experiment/vqgan_iam_lines.yaml
index 8bdf415..890948c 100644
--- a/training/conf/experiment/vqgan_iam_lines.yaml
+++ b/training/conf/experiment/vqgan_iam_lines.yaml
@@ -1,5 +1,3 @@
-# @package _global_
-
 defaults:
   - override /network: null
   - override /criterion: null
diff --git a/training/conf/experiment/vqvae.yaml b/training/conf/experiment/vqvae.yaml
index 6e42690..d069aef 100644
--- a/training/conf/experiment/vqvae.yaml
+++ b/training/conf/experiment/vqvae.yaml
@@ -1,5 +1,3 @@
-# @package _global_
-
 defaults:
   - override /network: vqvae
   - override /criterion: mse
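
Several experiment configs above drop their `# @package _global_` header; in Hydra that directive controls whether an experiment's keys are re-rooted at the top of the composed config or stay under the experiment package. A quick, hedged way to check where the keys land after this change is the compose API (the config_path below is an assumed relative path; use `experiment=...` instead of `+experiment=...` if the group is already in the primary defaults):

from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(config_path="training/conf"):
    cfg = compose(config_name="config", overrides=["+experiment=vqgan"])
    # Inspect whether the experiment keys override top-level groups
    # (the @package _global_ behaviour) or end up under cfg.experiment.
    print(OmegaConf.to_yaml(cfg, resolve=False))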
diff --git a/training/conf/experiment/vqvae_pixelcnn.yaml b/training/conf/experiment/vqvae_pixelcnn.yaml
deleted file mode 100644
index 4fae782..0000000
--- a/training/conf/experiment/vqvae_pixelcnn.yaml
+++ /dev/null
@@ -1,24 +0,0 @@
-# @package _global_
-
-defaults:
- - override /network: vqvae_pixelcnn
- - override /criterion: mae
- - override /model: lit_vqvae
- - override /callbacks: wandb_vae
- - override /lr_schedulers:
- - cosine_annealing
-
-trainer:
- max_epochs: 256
- # gradient_clip_val: 0.25
-
-datamodule:
- batch_size: 8
-
-# lr_scheduler:
- # epochs: 64
- # steps_per_epoch: 1245
-
-# optimizer:
- # lr: 1.0e-3
-
diff --git a/training/conf/mapping/characters.yaml b/training/conf/mapping/characters.yaml
index 14e966b..41a26a3 100644
--- a/training/conf/mapping/characters.yaml
+++ b/training/conf/mapping/characters.yaml
@@ -1,2 +1,2 @@
-_target_: text_recognizer.data.emnist_mapping.EmnistMapping
+_target_: text_recognizer.data.mappings.emnist_mapping.EmnistMapping
 extra_symbols: [ "\n" ]
diff --git a/training/conf/mapping/word_piece.yaml b/training/conf/mapping/word_piece.yaml
index 48384f5..ca8dd9c 100644
--- a/training/conf/mapping/word_piece.yaml
+++ b/training/conf/mapping/word_piece.yaml
@@ -1,4 +1,4 @@
-_target_: text_recognizer.data.word_piece_mapping.WordPieceMapping
+_target_: text_recognizer.data.mappings.word_piece_mapping.WordPieceMapping
 num_features: 1000
 tokens: iamdb_1kwp_tokens_1000.txt
 lexicon: iamdb_1kwp_lex_1000.txt
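
Both mapping configs now point their `_target_` at a text_recognizer.data.mappings subpackage. A small sanity check that resolves the new dotted paths the same way Hydra does (split module from class and import it) can catch a stale path before a full training run; it assumes the text_recognizer package is importable:

from importlib import import_module

# New _target_ paths from the updated mapping configs.
targets = (
    "text_recognizer.data.mappings.emnist_mapping.EmnistMapping",
    "text_recognizer.data.mappings.word_piece_mapping.WordPieceMapping",
)

for target in targets:
    module_path, _, class_name = target.rpartition(".")
    cls = getattr(import_module(module_path), class_name)
    print(f"{target} -> {cls}")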