From 3b06ef615a8db67a03927576e0c12fbfb2501f5f Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Mon, 14 Sep 2020 22:15:47 +0200 Subject: Fixed CTC loss. --- src/training/experiments/default_config_emnist.yml | 69 +++++++++++++++ .../experiments/iam_line_ctc_experiment.yml | 94 +++++++++++++++++++++ src/training/experiments/line_ctc_experiment.yml | 98 ++++++++++++++++++++++ 3 files changed, 261 insertions(+) create mode 100644 src/training/experiments/default_config_emnist.yml create mode 100644 src/training/experiments/iam_line_ctc_experiment.yml create mode 100644 src/training/experiments/line_ctc_experiment.yml (limited to 'src/training/experiments') diff --git a/src/training/experiments/default_config_emnist.yml b/src/training/experiments/default_config_emnist.yml new file mode 100644 index 0000000..12a0a9d --- /dev/null +++ b/src/training/experiments/default_config_emnist.yml @@ -0,0 +1,69 @@ +dataset: EmnistDataset +dataset_args: + sample_to_balance: true + subsample_fraction: 0.33 + transform: null + target_transform: null + seed: 4711 + +data_loader_args: + splits: [train, val] + shuffle: true + num_workers: 8 + cuda: true + +model: CharacterModel +metrics: [accuracy] + +network_args: + in_channels: 1 + num_classes: 80 + depths: [2] + block_sizes: [256] + +train_args: + batch_size: 256 + epochs: 5 + +criterion: CrossEntropyLoss +criterion_args: + weight: null + ignore_index: -100 + reduction: mean + +optimizer: AdamW +optimizer_args: + lr: 1.e-03 + betas: [0.9, 0.999] + eps: 1.e-08 + # weight_decay: 5.e-4 + amsgrad: false + +lr_scheduler: OneCycleLR +lr_scheduler_args: + max_lr: 1.e-03 + epochs: 5 + anneal_strategy: linear + + +callbacks: [Checkpoint, ProgressBar, EarlyStopping, WandbCallback, WandbImageLogger, OneCycleLR] +callback_args: + Checkpoint: + monitor: val_accuracy + ProgressBar: + epochs: 5 + log_batch_frequency: 100 + EarlyStopping: + monitor: val_loss + min_delta: 0.0 + patience: 3 + mode: min + WandbCallback: + log_batch_frequency: 10 + WandbImageLogger: + num_examples: 4 + OneCycleLR: + null +verbosity: 1 # 0, 1, 2 +resume_experiment: null +validation_metric: val_accuracy diff --git a/src/training/experiments/iam_line_ctc_experiment.yml b/src/training/experiments/iam_line_ctc_experiment.yml new file mode 100644 index 0000000..141c74e --- /dev/null +++ b/src/training/experiments/iam_line_ctc_experiment.yml @@ -0,0 +1,94 @@ +experiment_group: Sample Experiments +experiments: + - train_args: + batch_size: 24 + max_epochs: 128 + dataset: + type: IamLinesDataset + args: + subsample_fraction: null + transform: null + target_transform: null + train_args: + num_workers: 6 + train_fraction: 0.85 + model: LineCTCModel + metrics: [cer, wer] + network: + type: LineRecurrentNetwork + args: + # encoder: ResidualNetworkEncoder + # encoder_args: + # in_channels: 1 + # num_classes: 80 + # depths: [2, 2] + # block_sizes: [128, 128] + # activation: SELU + # stn: false + encoder: WideResidualNetwork + encoder_args: + in_channels: 1 + num_classes: 80 + depth: 16 + num_layers: 4 + width_factor: 2 + dropout_rate: 0.2 + activation: selu + use_decoder: false + flatten: true + input_size: 256 + hidden_size: 128 + num_layers: 2 + num_classes: 80 + patch_size: [28, 14] + stride: [1, 5] + criterion: + type: CTCLoss + args: + blank: 79 + optimizer: + type: AdamW + args: + lr: 1.e-03 + betas: [0.9, 0.999] + eps: 1.e-08 + weight_decay: false + amsgrad: false + # lr_scheduler: + # type: OneCycleLR + # args: + # max_lr: 1.e-02 + # epochs: null + # anneal_strategy: linear + lr_scheduler: + type: CosineAnnealingLR + args: + T_max: null + swa_args: + start: 75 + lr: 5.e-2 + callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger, SWA] # EarlyStopping, OneCycleLR] + callback_args: + Checkpoint: + monitor: val_loss + mode: min + ProgressBar: + epochs: null + # log_batch_frequency: 100 + # EarlyStopping: + # monitor: val_loss + # min_delta: 0.0 + # patience: 7 + # mode: min + WandbCallback: + log_batch_frequency: 10 + WandbImageLogger: + num_examples: 6 + # OneCycleLR: + # null + SWA: + null + verbosity: 1 # 0, 1, 2 + resume_experiment: null + test: true + test_metric: test_cer diff --git a/src/training/experiments/line_ctc_experiment.yml b/src/training/experiments/line_ctc_experiment.yml new file mode 100644 index 0000000..c21c6a2 --- /dev/null +++ b/src/training/experiments/line_ctc_experiment.yml @@ -0,0 +1,98 @@ +experiment_group: Sample Experiments +experiments: + - train_args: + batch_size: 64 + max_epochs: 32 + dataset: + type: EmnistLinesDataset + args: + subsample_fraction: 0.33 + max_length: 34 + min_overlap: 0 + max_overlap: 0.33 + num_samples: 10000 + seed: 4711 + blank: true + train_args: + num_workers: 6 + train_fraction: 0.85 + model: LineCTCModel + metrics: [cer, wer] + network: + type: LineRecurrentNetwork + args: + # encoder: ResidualNetworkEncoder + # encoder_args: + # in_channels: 1 + # num_classes: 81 + # depths: [2, 2] + # block_sizes: [64, 128] + # activation: SELU + # stn: false + encoder: WideResidualNetwork + encoder_args: + in_channels: 1 + num_classes: 81 + depth: 16 + num_layers: 4 + width_factor: 2 + dropout_rate: 0.2 + activation: selu + use_decoder: false + flatten: true + input_size: 256 + hidden_size: 128 + num_layers: 2 + num_classes: 81 + patch_size: [28, 14] + stride: [1, 5] + criterion: + type: CTCLoss + args: + blank: 80 + optimizer: + type: AdamW + args: + lr: 1.e-02 + betas: [0.9, 0.999] + eps: 1.e-08 + weight_decay: 5.e-4 + amsgrad: false + # lr_scheduler: + # type: OneCycleLR + # args: + # max_lr: 1.e-03 + # epochs: null + # anneal_strategy: linear + lr_scheduler: + type: CosineAnnealingLR + args: + T_max: null + swa_args: + start: 4 + lr: 5.e-2 + callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger, SWA] # EarlyStopping, OneCycleLR] + callback_args: + Checkpoint: + monitor: val_loss + mode: min + ProgressBar: + epochs: null + log_batch_frequency: 100 + # EarlyStopping: + # monitor: val_loss + # min_delta: 0.0 + # patience: 5 + # mode: min + WandbCallback: + log_batch_frequency: 10 + WandbImageLogger: + num_examples: 6 + # OneCycleLR: + # null + SWA: + null + verbosity: 1 # 0, 1, 2 + resume_experiment: null + test: true + test_metric: test_cer -- cgit v1.2.3-70-g09d2