summaryrefslogtreecommitdiff
path: root/src/training
diff options
context:
space:
mode:
Diffstat (limited to 'src/training')
-rw-r--r--src/training/experiments/default_config_emnist.yml1
-rw-r--r--src/training/experiments/embedding_experiment.yml58
-rw-r--r--src/training/experiments/line_ctc_experiment.yml3
-rw-r--r--src/training/experiments/sample_experiment.yml1
-rw-r--r--src/training/run_experiment.py3
-rw-r--r--src/training/trainer/callbacks/base.py2
6 files changed, 65 insertions, 3 deletions
diff --git a/src/training/experiments/default_config_emnist.yml b/src/training/experiments/default_config_emnist.yml
index 12a0a9d..bf2ed0a 100644
--- a/src/training/experiments/default_config_emnist.yml
+++ b/src/training/experiments/default_config_emnist.yml
@@ -66,4 +66,5 @@ callback_args:
null
verbosity: 1 # 0, 1, 2
resume_experiment: null
+train: true
validation_metric: val_accuracy
diff --git a/src/training/experiments/embedding_experiment.yml b/src/training/experiments/embedding_experiment.yml
new file mode 100644
index 0000000..e674c26
--- /dev/null
+++ b/src/training/experiments/embedding_experiment.yml
@@ -0,0 +1,58 @@
+experiment_group: Embedding Experiments
+experiments:
+ - train_args:
+ batch_size: 256
+ max_epochs: &max_epochs 8
+ dataset:
+ type: EmnistDataset
+ args:
+ sample_to_balance: true
+ subsample_fraction: null
+ transform: null
+ target_transform: null
+ seed: 4711
+ train_args:
+ num_workers: 8
+ train_fraction: 0.85
+ model: CharacterModel
+ metrics: []
+ network:
+ type: ResidualNetwork
+ args:
+ in_channels: 1
+ num_classes: 64 # Embedding
+ depths: [2,2]
+ block_sizes: [32, 64]
+ activation: selu
+ stn: false
+ criterion:
+ type: EmbeddingLoss
+ args:
+ margin: 0.2
+ type_of_triplets: semihard
+ optimizer:
+ type: AdamW
+ args:
+ lr: 1.e-02
+ betas: [0.9, 0.999]
+ eps: 1.e-08
+ weight_decay: 5.e-4
+ amsgrad: false
+ lr_scheduler:
+ type: CosineAnnealingLR
+ args:
+ T_max: *max_epochs
+ callbacks: [Checkpoint, ProgressBar, WandbCallback]
+ callback_args:
+ Checkpoint:
+ monitor: val_loss
+ mode: min
+ ProgressBar:
+ epochs: *max_epochs
+ WandbCallback:
+ log_batch_frequency: 10
+ verbosity: 1 # 0, 1, 2
+ resume_experiment: null
+ train: true
+ test: true
+ test_metric: mean_average_precision_at_r
diff --git a/src/training/experiments/line_ctc_experiment.yml b/src/training/experiments/line_ctc_experiment.yml
index f309990..ef97527 100644
--- a/src/training/experiments/line_ctc_experiment.yml
+++ b/src/training/experiments/line_ctc_experiment.yml
@@ -27,7 +27,7 @@ experiments:
# stn: false
backbone: ResidualNetwork
backbone_args:
- pretrained: training/experiments/CharacterModel_EmnistDataset_ResidualNetwork/0920_010806/model/best.pt
+ pretrained: training/experiments/CharacterModel_EmnistDataset_ResidualNetwork/0920_025816/model/best.pt
freeze: false
flatten: false
input_size: 64
@@ -87,5 +87,6 @@ experiments:
num_examples: 6
verbosity: 1 # 0, 1, 2
resume_experiment: null
+ train: true
test: true
test_metric: test_cer
diff --git a/src/training/experiments/sample_experiment.yml b/src/training/experiments/sample_experiment.yml
index 8664a15..a073a87 100644
--- a/src/training/experiments/sample_experiment.yml
+++ b/src/training/experiments/sample_experiment.yml
@@ -95,5 +95,6 @@ experiments:
use_transpose: true
verbosity: 0 # 0, 1, 2
resume_experiment: null
+ train: true
test: true
test_metric: test_accuracy
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py
index cc882ad..9d45841 100644
--- a/src/training/run_experiment.py
+++ b/src/training/run_experiment.py
@@ -273,7 +273,8 @@ def run_experiment(
)
# Train the model.
- trainer.fit(model)
+ if experiment_config["train"]:
+ trainer.fit(model)
# Run inference over test set.
if experiment_config["test"]:
diff --git a/src/training/trainer/callbacks/base.py b/src/training/trainer/callbacks/base.py
index 8c7b085..f81fc1f 100644
--- a/src/training/trainer/callbacks/base.py
+++ b/src/training/trainer/callbacks/base.py
@@ -92,7 +92,7 @@ class CallbackList:
def append(self, callback: Type[Callback]) -> None:
"""Append new callback to callback list."""
- self.callbacks.append(callback)
+ self._callbacks.append(callback)
def on_fit_begin(self) -> None:
"""Called when fit begins."""