diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/conf/experiment/conformer_lines.yaml | 123 |
1 files changed, 123 insertions, 0 deletions
diff --git a/training/conf/experiment/conformer_lines.yaml b/training/conf/experiment/conformer_lines.yaml
new file mode 100644
index 0000000..c3f4ea5
--- /dev/null
+++ b/training/conf/experiment/conformer_lines.yaml
@@ -0,0 +1,123 @@
+# @package _global_
+
+defaults:
+  - override /mapping: null
+  - override /criterion: ctc
+  - override /callbacks: htr
+  - override /datamodule: iam_lines
+  - override /network: null
+  - override /model: null
+  - override /lr_schedulers: null
+  - override /optimizers: null
+
+epochs: &epochs 200
+num_classes: &num_classes 57
+max_output_len: &max_output_len 762
+summary: [[1, 57, 1024]]
+
+mapping: &mapping
+  mapping:
+    _target_: text_recognizer.data.mappings.EmnistMapping
+
+callbacks:
+  stochastic_weight_averaging:
+    _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+    swa_epoch_start: 0.75
+    swa_lrs: 1.0e-5
+    annealing_epochs: 10
+    annealing_strategy: cos
+    device: null
+
+optimizers:
+  radam:
+    _target_: torch.optim.RAdam
+    lr: 3.0e-4
+    betas: [0.9, 0.999]
+    weight_decay: 0
+    eps: 1.0e-8
+    parameters: network
+
+lr_schedulers:
+  network:
+    _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+    mode: min
+    factor: 0.5
+    patience: 10
+    threshold: 1.0e-4
+    threshold_mode: rel
+    cooldown: 0
+    min_lr: 1.0e-5
+    eps: 1.0e-8
+    verbose: false
+    interval: epoch
+    monitor: val/loss
+
+datamodule:
+  batch_size: 8
+  num_workers: 12
+  train_fraction: 0.9
+  pin_memory: true
+  << : *mapping
+
+network:
+  _target_: text_recognizer.networks.conformer.Conformer
+  depth: 16
+  num_classes: *num_classes
+  dim: &dim 128
+  block:
+    _target_: text_recognizer.networks.conformer.ConformerBlock
+    dim: *dim
+    attn:
+      _target_: text_recognizer.networks.conformer.Attention
+      dim: *dim
+      heads: 8
+      dim_head: 64
+      mult: 4
+    ff:
+      _target_: text_recognizer.networks.conformer.Feedforward
+      dim: *dim
+      expansion_factor: 4
+      dropout: 0.1
+    conv:
+      _target_: text_recognizer.networks.conformer.ConformerConv
+      dim: *dim
+      expansion_factor: 2
+      kernel_size: 31
+      dropout: 0.1
+  subsampler:
+    _target_: text_recognizer.networks.conformer.Subsampler
+    pixel_pos_embedding:
+      _target_: text_recognizer.networks.transformer.AxialPositionalEmbedding
+      dim: *dim
+      shape: [6, 127]
+    channels: *dim
+    depth: 3
+    dropout: 0.1
+
+model:
+  _target_: text_recognizer.models.conformer.LitConformer
+  <<: *mapping
+  max_output_len: *max_output_len
+  start_token: <s>
+  end_token: <e>
+  pad_token: <p>
+  blank_token: <b>
+
+trainer:
+  _target_: pytorch_lightning.Trainer
+  stochastic_weight_avg: true
+  auto_scale_batch_size: binsearch
+  auto_lr_find: false
+  gradient_clip_val: 0.5
+  fast_dev_run: false
+  gpus: 1
+  precision: 16
+  max_epochs: *epochs
+  terminate_on_nan: true
+  weights_summary: null
+  limit_train_batches: 1.0
+  limit_val_batches: 1.0
+  limit_test_batches: 1.0
+  resume_from_checkpoint: null
+  accumulate_grad_batches: 1
+  overfit_batches: 0