summaryrefslogtreecommitdiff
path: root/src/training/experiments
diff options
context:
space:
mode:
Diffstat (limited to 'src/training/experiments')
-rw-r--r--src/training/experiments/default_config_emnist.yml1
-rw-r--r--src/training/experiments/embedding_experiment.yml64
-rw-r--r--src/training/experiments/line_ctc_experiment.yml91
-rw-r--r--src/training/experiments/sample_experiment.yml1
4 files changed, 66 insertions, 91 deletions
diff --git a/src/training/experiments/default_config_emnist.yml b/src/training/experiments/default_config_emnist.yml
index 12a0a9d..bf2ed0a 100644
--- a/src/training/experiments/default_config_emnist.yml
+++ b/src/training/experiments/default_config_emnist.yml
@@ -66,4 +66,5 @@ callback_args:
null
verbosity: 1 # 0, 1, 2
resume_experiment: null
+train: true
validation_metric: val_accuracy
diff --git a/src/training/experiments/embedding_experiment.yml b/src/training/experiments/embedding_experiment.yml
new file mode 100644
index 0000000..1e5f941
--- /dev/null
+++ b/src/training/experiments/embedding_experiment.yml
@@ -0,0 +1,64 @@
+experiment_group: Embedding Experiments
+experiments:
+ - train_args:
+ transformer_model: false
+ batch_size: &batch_size 256
+ max_epochs: &max_epochs 32
+ input_shape: [[1, 28, 28]]
+ dataset:
+ type: EmnistDataset
+ args:
+ sample_to_balance: true
+ subsample_fraction: null
+ transform: null
+ target_transform: null
+ seed: 4711
+ train_args:
+ num_workers: 8
+ train_fraction: 0.85
+ batch_size: *batch_size
+ model: CharacterModel
+ metrics: []
+ network:
+ type: DenseNet
+ args:
+ growth_rate: 4
+ block_config: [4, 4]
+ in_channels: 1
+ base_channels: 24
+ num_classes: 128
+ bn_size: 4
+ dropout_rate: 0.1
+ classifier: true
+ activation: elu
+ criterion:
+ type: EmbeddingLoss
+ args:
+ margin: 0.2
+ type_of_triplets: semihard
+ optimizer:
+ type: AdamW
+ args:
+ lr: 1.e-02
+ betas: [0.9, 0.999]
+ eps: 1.e-08
+ weight_decay: 5.e-4
+ amsgrad: false
+ lr_scheduler:
+ type: CosineAnnealingLR
+ args:
+ T_max: *max_epochs
+ callbacks: [Checkpoint, ProgressBar, WandbCallback]
+ callback_args:
+ Checkpoint:
+ monitor: val_loss
+ mode: min
+ ProgressBar:
+ epochs: *max_epochs
+ WandbCallback:
+ log_batch_frequency: 10
+ verbosity: 1 # 0, 1, 2
+ resume_experiment: null
+ train: true
+ test: true
+ test_metric: mean_average_precision_at_r
diff --git a/src/training/experiments/line_ctc_experiment.yml b/src/training/experiments/line_ctc_experiment.yml
deleted file mode 100644
index 432d1cc..0000000
--- a/src/training/experiments/line_ctc_experiment.yml
+++ /dev/null
@@ -1,91 +0,0 @@
-experiment_group: Lines Experiments
-experiments:
- - train_args:
- batch_size: 42
- max_epochs: &max_epochs 32
- dataset:
- type: IamLinesDataset
- args:
- subsample_fraction: null
- transform: null
- target_transform: null
- train_args:
- num_workers: 8
- train_fraction: 0.85
- model: LineCTCModel
- metrics: [cer, wer]
- network:
- type: LineRecurrentNetwork
- args:
- backbone: ResidualNetwork
- backbone_args:
- in_channels: 1
- num_classes: 64 # Embedding
- depths: [2,2]
- block_sizes: [32,64]
- activation: selu
- stn: false
- # encoder: ResidualNetwork
- # encoder_args:
- # pretrained: training/experiments/CharacterModel_EmnistDataset_ResidualNetwork/0917_203601/model/best.pt
- # freeze: false
- flatten: false
- input_size: 64
- hidden_size: 64
- bidirectional: true
- num_layers: 2
- num_classes: 80
- patch_size: [28, 18]
- stride: [1, 4]
- criterion:
- type: CTCLoss
- args:
- blank: 79
- optimizer:
- type: AdamW
- args:
- lr: 1.e-02
- betas: [0.9, 0.999]
- eps: 1.e-08
- weight_decay: 5.e-4
- amsgrad: false
- lr_scheduler:
- type: OneCycleLR
- args:
- max_lr: 1.e-02
- epochs: *max_epochs
- anneal_strategy: cos
- pct_start: 0.475
- cycle_momentum: true
- base_momentum: 0.85
- max_momentum: 0.9
- div_factor: 10
- final_div_factor: 10000
- interval: step
- # lr_scheduler:
- # type: CosineAnnealingLR
- # args:
- # T_max: *max_epochs
- swa_args:
- start: 24
- lr: 5.e-2
- callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger] # EarlyStopping]
- callback_args:
- Checkpoint:
- monitor: val_loss
- mode: min
- ProgressBar:
- epochs: *max_epochs
- # EarlyStopping:
- # monitor: val_loss
- # min_delta: 0.0
- # patience: 10
- # mode: min
- WandbCallback:
- log_batch_frequency: 10
- WandbImageLogger:
- num_examples: 6
- verbosity: 1 # 0, 1, 2
- resume_experiment: null
- test: true
- test_metric: test_cer
diff --git a/src/training/experiments/sample_experiment.yml b/src/training/experiments/sample_experiment.yml
index 8664a15..a073a87 100644
--- a/src/training/experiments/sample_experiment.yml
+++ b/src/training/experiments/sample_experiment.yml
@@ -95,5 +95,6 @@ experiments:
use_transpose: true
verbosity: 0 # 0, 1, 2
resume_experiment: null
+ train: true
test: true
test_metric: test_accuracy