author    aktersnurra <gustaf.rydholm@gmail.com>  2020-09-20 01:15:23 +0200
committer aktersnurra <gustaf.rydholm@gmail.com>  2020-09-20 01:15:23 +0200
commit    2b743f2f9046a2e2930647d234dc7392d71efa66 (patch)
tree      237b5729cfbae90877a44baa52ac2ebaec388f79 /src/training
parent    e181195a699d7fa237f256d90ab4dedffc03d405 (diff)
Fixed some bash scripts.
Diffstat (limited to 'src/training')
-rw-r--r--  src/training/experiments/line_ctc_experiment.yml | 36
-rw-r--r--  src/training/run_experiment.py                   |  2
2 files changed, 19 insertions(+), 19 deletions(-)
diff --git a/src/training/experiments/line_ctc_experiment.yml b/src/training/experiments/line_ctc_experiment.yml
index 432d1cc..337c830 100644
--- a/src/training/experiments/line_ctc_experiment.yml
+++ b/src/training/experiments/line_ctc_experiment.yml
@@ -1,7 +1,7 @@
 experiment_group: Lines Experiments
 experiments:
   - train_args:
-      batch_size: 42
+      batch_size: 64
       max_epochs: &max_epochs 32
     dataset:
       type: IamLinesDataset
@@ -17,18 +17,18 @@ experiments:
     network:
       type: LineRecurrentNetwork
       args:
+        # backbone: ResidualNetwork
+        # backbone_args:
+        #   in_channels: 1
+        #   num_classes: 64 # Embedding
+        #   depths: [2,2]
+        #   block_sizes: [32, 64]
+        #   activation: selu
+        #   stn: false
         backbone: ResidualNetwork
         backbone_args:
-          in_channels: 1
-          num_classes: 64 # Embedding
-          depths: [2,2]
-          block_sizes: [32,64]
-          activation: selu
-          stn: false
-        # encoder: ResidualNetwork
-        # encoder_args:
-        #   pretrained: training/experiments/CharacterModel_EmnistDataset_ResidualNetwork/0917_203601/model/best.pt
-        #   freeze: false
+          pretrained: training/experiments/CharacterModel_EmnistDataset_ResidualNetwork/0920_010806/model/best.pt
+          freeze: false
         flatten: false
         input_size: 64
         hidden_size: 64
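
Note: the new backbone_args replace the hand-tuned layer hyperparameters with a path to a pretrained character-model checkpoint plus a freeze flag. A minimal sketch of how such a checkpoint could be loaded and optionally frozen; the helper below is illustrative, assumes best.pt holds a plain state dict, and is not this repository's loader:

    import torch

    def load_backbone(backbone: torch.nn.Module, pretrained: str, freeze: bool) -> torch.nn.Module:
        """Load weights from the `pretrained` path and optionally freeze them."""
        state_dict = torch.load(pretrained, map_location="cpu")
        backbone.load_state_dict(state_dict)
        if freeze:
            # Frozen parameters receive no gradient updates during training.
            for param in backbone.parameters():
                param.requires_grad = False
        return backbone

With freeze: false, as in this experiment, the pretrained weights are fine-tuned along with the rest of the network.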
@@ -67,20 +67,20 @@ experiments:
       # args:
       #   T_max: *max_epochs
       swa_args:
-        start: 24
+        start: 48
         lr: 5.e-2
-      callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger] # EarlyStopping]
+      callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger, EarlyStopping]
       callback_args:
         Checkpoint:
           monitor: val_loss
           mode: min
         ProgressBar:
           epochs: *max_epochs
-        # EarlyStopping:
-        #   monitor: val_loss
-        #   min_delta: 0.0
-        #   patience: 10
-        #   mode: min
+        EarlyStopping:
+          monitor: val_loss
+          min_delta: 0.0
+          patience: 10
+          mode: min
         WandbCallback:
           log_batch_frequency: 10
         WandbImageLogger:
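
Note: the callbacks change moves EarlyStopping from a commented-out block into the active list, monitoring val_loss with mode min, min_delta 0.0, and patience 10. A hedged sketch of what an early-stopping callback with those arguments typically does; the class below is illustrative, not this repository's implementation:

    class EarlyStopping:
        """Stop training when the monitored metric stops improving."""

        def __init__(self, monitor: str = "val_loss", min_delta: float = 0.0,
                     patience: int = 10, mode: str = "min") -> None:
            self.monitor = monitor
            self.min_delta = min_delta
            self.patience = patience
            # In "min" mode smaller is better; negate the metric for "max" mode.
            self.sign = 1.0 if mode == "min" else -1.0
            self.best = float("inf")
            self.wait = 0
            self.stop_training = False

        def on_epoch_end(self, metrics: dict) -> None:
            value = self.sign * metrics[self.monitor]
            if value < self.best - self.min_delta:
                self.best = value  # improvement seen: reset the counter
                self.wait = 0
            else:
                self.wait += 1  # no improvement this epoch
                if self.wait >= self.patience:
                    self.stop_training = True  # training loop should halt

With patience 10 and min_delta 0.0, training stops only after ten consecutive epochs without any reduction in val_loss.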
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py
index a347d9f..cc882ad 100644
--- a/src/training/run_experiment.py
+++ b/src/training/run_experiment.py
@@ -116,7 +116,7 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]
     # Learning rate scheduler
     lr_scheduler_ = None
     lr_scheduler_args = None
-    if experiment_config["lr_scheduler"] is not None:
+    if "lr_scheduler" in experiment_config:
         lr_scheduler_ = getattr(
             torch.optim.lr_scheduler, experiment_config["lr_scheduler"]["type"]
         )
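
Note: the run_experiment.py change replaces a value check with a key-presence check. The old guard indexes experiment_config["lr_scheduler"] first, so a config that omits the key entirely raises KeyError before the is-not-None comparison is ever evaluated; the membership test tolerates a missing key. A standalone sketch of the difference, independent of the project's config schema:

    from typing import Any, Dict

    config: Dict[str, Any] = {"optimizer": "Adam"}  # no "lr_scheduler" key

    # Old guard: indexing a missing key raises KeyError.
    try:
        if config["lr_scheduler"] is not None:
            print("scheduler configured")
    except KeyError:
        print("old guard: KeyError when the key is absent")

    # New guard: a membership test is safe when the key is absent.
    if "lr_scheduler" in config:
        print("scheduler configured")
    else:
        print("new guard: skip scheduler setup quietly")

The two guards also differ when the key is present but null: the old check skips a null scheduler, while the new one proceeds and would then fail on experiment_config["lr_scheduler"]["type"].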