author     aktersnurra <gustaf.rydholm@gmail.com>    2020-09-14 22:15:47 +0200
committer  aktersnurra <gustaf.rydholm@gmail.com>    2020-09-14 22:15:47 +0200
commit     3b06ef615a8db67a03927576e0c12fbfb2501f5f (patch)
tree       e1c2b1289971c8480327408de46152481e99b539 /src/training
parent     2b63fd952bdc9c7c72edd501cbcdbf3231e98f00 (diff)
Fixed CTC loss.
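
The configs added below pair each dataset's character set with an explicit blank index for CTC: class 79 of 80 for the IAM line config, class 80 of 81 for the EMNIST lines config. As a minimal sketch of what that index means at the loss call, assuming the CTCLoss type in these configs resolves to torch.nn.CTCLoss (shapes and lengths are illustrative, not taken from the repo):

    import torch
    import torch.nn as nn

    # Illustrative shapes only: T time steps out of the recurrent network,
    # N lines per batch, C classes including the blank.
    T, N, C = 97, 24, 80
    criterion = nn.CTCLoss(blank=79, reduction="mean")  # blank: 79 in the IAM config

    log_probs = torch.randn(T, N, C).log_softmax(2)            # (T, N, C) log-probabilities
    targets = torch.randint(0, 79, (N, 30), dtype=torch.long)  # labels never contain the blank
    input_lengths = torch.full((N,), T, dtype=torch.long)
    target_lengths = torch.randint(10, 31, (N,), dtype=torch.long)

    loss = criterion(log_probs, targets, input_lengths, target_lengths)

Note that target sequences must never contain the blank index itself; CTC inserts blanks internally to separate repeated characters.
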
Diffstat (limited to 'src/training')
-rw-r--r--  src/training/experiments/default_config_emnist.yml    69
-rw-r--r--  src/training/experiments/iam_line_ctc_experiment.yml  94
-rw-r--r--  src/training/experiments/line_ctc_experiment.yml      98
-rw-r--r--  src/training/run_experiment.py                          2
4 files changed, 262 insertions(+), 1 deletion(-)
diff --git a/src/training/experiments/default_config_emnist.yml b/src/training/experiments/default_config_emnist.yml
new file mode 100644
index 0000000..12a0a9d
--- /dev/null
+++ b/src/training/experiments/default_config_emnist.yml
@@ -0,0 +1,69 @@
+dataset: EmnistDataset
+dataset_args:
+ sample_to_balance: true
+ subsample_fraction: 0.33
+ transform: null
+ target_transform: null
+ seed: 4711
+
+data_loader_args:
+ splits: [train, val]
+ shuffle: true
+ num_workers: 8
+ cuda: true
+
+model: CharacterModel
+metrics: [accuracy]
+
+network_args:
+ in_channels: 1
+ num_classes: 80
+ depths: [2]
+ block_sizes: [256]
+
+train_args:
+ batch_size: 256
+ epochs: 5
+
+criterion: CrossEntropyLoss
+criterion_args:
+ weight: null
+ ignore_index: -100
+ reduction: mean
+
+optimizer: AdamW
+optimizer_args:
+ lr: 1.e-03
+ betas: [0.9, 0.999]
+ eps: 1.e-08
+ # weight_decay: 5.e-4
+ amsgrad: false
+
+lr_scheduler: OneCycleLR
+lr_scheduler_args:
+ max_lr: 1.e-03
+ epochs: 5
+ anneal_strategy: linear
+
+
+callbacks: [Checkpoint, ProgressBar, EarlyStopping, WandbCallback, WandbImageLogger, OneCycleLR]
+callback_args:
+ Checkpoint:
+ monitor: val_accuracy
+ ProgressBar:
+ epochs: 5
+ log_batch_frequency: 100
+ EarlyStopping:
+ monitor: val_loss
+ min_delta: 0.0
+ patience: 3
+ mode: min
+ WandbCallback:
+ log_batch_frequency: 10
+ WandbImageLogger:
+ num_examples: 4
+ OneCycleLR:
+ null
+verbosity: 1 # 0, 1, 2
+resume_experiment: null
+validation_metric: val_accuracy
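
The config above names its loss, optimizer, and scheduler by torch class name next to a matching *_args block. A hypothetical sketch of the resolution step; the actual loader lives in run_experiment.py, which this diff only touches for logging, so build_from_config and its exact behavior here are assumptions:

    import torch.nn
    import torch.optim

    def build_from_config(config: dict, model_parameters):
        # Resolve class names from the YAML against torch's namespaces,
        # e.g. CrossEntropyLoss and AdamW in the config above.
        criterion_cls = getattr(torch.nn, config["criterion"])
        criterion = criterion_cls(**config["criterion_args"])

        optimizer_cls = getattr(torch.optim, config["optimizer"])
        optimizer = optimizer_cls(model_parameters, **config["optimizer_args"])
        return criterion, optimizer

Usage would then look like criterion, optimizer = build_from_config(config, model.parameters()).
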
diff --git a/src/training/experiments/iam_line_ctc_experiment.yml b/src/training/experiments/iam_line_ctc_experiment.yml
new file mode 100644
index 0000000..141c74e
--- /dev/null
+++ b/src/training/experiments/iam_line_ctc_experiment.yml
@@ -0,0 +1,94 @@
+experiment_group: Sample Experiments
+experiments:
+ - train_args:
+ batch_size: 24
+ max_epochs: 128
+ dataset:
+ type: IamLinesDataset
+ args:
+ subsample_fraction: null
+ transform: null
+ target_transform: null
+ train_args:
+ num_workers: 6
+ train_fraction: 0.85
+ model: LineCTCModel
+ metrics: [cer, wer]
+ network:
+ type: LineRecurrentNetwork
+ args:
+ # encoder: ResidualNetworkEncoder
+ # encoder_args:
+ # in_channels: 1
+ # num_classes: 80
+ # depths: [2, 2]
+ # block_sizes: [128, 128]
+ # activation: SELU
+ # stn: false
+ encoder: WideResidualNetwork
+ encoder_args:
+ in_channels: 1
+ num_classes: 80
+ depth: 16
+ num_layers: 4
+ width_factor: 2
+ dropout_rate: 0.2
+ activation: selu
+ use_decoder: false
+ flatten: true
+ input_size: 256
+ hidden_size: 128
+ num_layers: 2
+ num_classes: 80
+ patch_size: [28, 14]
+ stride: [1, 5]
+ criterion:
+ type: CTCLoss
+ args:
+ blank: 79
+ optimizer:
+ type: AdamW
+ args:
+ lr: 1.e-03
+ betas: [0.9, 0.999]
+ eps: 1.e-08
+            weight_decay: 0.0
+ amsgrad: false
+ # lr_scheduler:
+ # type: OneCycleLR
+ # args:
+ # max_lr: 1.e-02
+ # epochs: null
+ # anneal_strategy: linear
+ lr_scheduler:
+ type: CosineAnnealingLR
+ args:
+ T_max: null
+ swa_args:
+ start: 75
+ lr: 5.e-2
+ callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger, SWA] # EarlyStopping, OneCycleLR]
+ callback_args:
+ Checkpoint:
+ monitor: val_loss
+ mode: min
+ ProgressBar:
+ epochs: null
+ # log_batch_frequency: 100
+ # EarlyStopping:
+ # monitor: val_loss
+ # min_delta: 0.0
+ # patience: 7
+ # mode: min
+ WandbCallback:
+ log_batch_frequency: 10
+ WandbImageLogger:
+ num_examples: 6
+ # OneCycleLR:
+ # null
+ SWA:
+ null
+ verbosity: 1 # 0, 1, 2
+ resume_experiment: null
+ test: true
+ test_metric: test_cer
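
This experiment anneals with CosineAnnealingLR and then switches to stochastic weight averaging at epoch 75 of 128 (swa_args). A hedged sketch of that handover using torch.optim.swa_utils; the stand-in model, the elided epoch body, and the loop itself are illustrative, not the repo's SWA callback:

    import torch
    from torch.optim.swa_utils import AveragedModel, SWALR

    model = torch.nn.Linear(10, 2)  # stand-in for LineRecurrentNetwork
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=128)

    swa_model = AveragedModel(model)
    swa_scheduler = SWALR(optimizer, swa_lr=5e-2)  # swa_args.lr above

    for epoch in range(128):  # max_epochs
        # ... one training epoch over the IAM lines loader ...
        if epoch >= 75:  # swa_args.start: average weights from here on
            swa_model.update_parameters(model)
            swa_scheduler.step()
        else:
            scheduler.step()
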
diff --git a/src/training/experiments/line_ctc_experiment.yml b/src/training/experiments/line_ctc_experiment.yml
new file mode 100644
index 0000000..c21c6a2
--- /dev/null
+++ b/src/training/experiments/line_ctc_experiment.yml
@@ -0,0 +1,98 @@
+experiment_group: Sample Experiments
+experiments:
+ - train_args:
+ batch_size: 64
+ max_epochs: 32
+ dataset:
+ type: EmnistLinesDataset
+ args:
+ subsample_fraction: 0.33
+ max_length: 34
+ min_overlap: 0
+ max_overlap: 0.33
+ num_samples: 10000
+ seed: 4711
+ blank: true
+ train_args:
+ num_workers: 6
+ train_fraction: 0.85
+ model: LineCTCModel
+ metrics: [cer, wer]
+ network:
+ type: LineRecurrentNetwork
+ args:
+ # encoder: ResidualNetworkEncoder
+ # encoder_args:
+ # in_channels: 1
+ # num_classes: 81
+ # depths: [2, 2]
+ # block_sizes: [64, 128]
+ # activation: SELU
+ # stn: false
+ encoder: WideResidualNetwork
+ encoder_args:
+ in_channels: 1
+ num_classes: 81
+ depth: 16
+ num_layers: 4
+ width_factor: 2
+ dropout_rate: 0.2
+ activation: selu
+ use_decoder: false
+ flatten: true
+ input_size: 256
+ hidden_size: 128
+ num_layers: 2
+ num_classes: 81
+ patch_size: [28, 14]
+ stride: [1, 5]
+ criterion:
+ type: CTCLoss
+ args:
+ blank: 80
+ optimizer:
+ type: AdamW
+ args:
+ lr: 1.e-02
+ betas: [0.9, 0.999]
+ eps: 1.e-08
+ weight_decay: 5.e-4
+ amsgrad: false
+ # lr_scheduler:
+ # type: OneCycleLR
+ # args:
+ # max_lr: 1.e-03
+ # epochs: null
+ # anneal_strategy: linear
+ lr_scheduler:
+ type: CosineAnnealingLR
+ args:
+ T_max: null
+ swa_args:
+ start: 4
+ lr: 5.e-2
+ callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger, SWA] # EarlyStopping, OneCycleLR]
+ callback_args:
+ Checkpoint:
+ monitor: val_loss
+ mode: min
+ ProgressBar:
+ epochs: null
+ log_batch_frequency: 100
+ # EarlyStopping:
+ # monitor: val_loss
+ # min_delta: 0.0
+ # patience: 5
+ # mode: min
+ WandbCallback:
+ log_batch_frequency: 10
+ WandbImageLogger:
+ num_examples: 6
+ # OneCycleLR:
+ # null
+ SWA:
+ null
+ verbosity: 1 # 0, 1, 2
+ resume_experiment: null
+ test: true
+ test_metric: test_cer
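
Both line configs feed the encoder patch_size: [28, 14] windows taken with stride: [1, 5] across the line image. The diff does not show how LineRecurrentNetwork performs the slicing; a sketch of equivalent windowing with Tensor.unfold, on a made-up line width:

    import torch

    line = torch.randn(1, 1, 28, 952)                 # (B, C, H, W) stand-in line image
    patches = line.unfold(2, 28, 1).unfold(3, 14, 5)  # 28x14 windows, stride 1x5
    # (B, C, 1, 188, 28, 14): fold the window grid into one sequence dimension
    patches = patches.reshape(1, 1, -1, 28, 14)
    print(patches.shape)  # one 28x14 patch per time step fed to the encoder
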
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py
index 4317d66..286b0c6 100644
--- a/src/training/run_experiment.py
+++ b/src/training/run_experiment.py
@@ -211,7 +211,7 @@ def run_experiment(
experiment_config: Dict, save_weights: bool, device: str, use_wandb: bool = False
) -> None:
"""Runs an experiment."""
- logger.info(f"Experiment config: {json.dumps(experiment_config)}")
+ logger.info(f"Experiment config: {json.dumps(experiment_config, indent=2)}")
# Create new experiment.
experiment_dir, log_dir, model_dir = create_experiment_dir(experiment_config)