From 80de5c023cb7740022a9b13b5f09b71d9e85f8e7 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 8 Jun 2022 08:41:40 +0200 Subject: Add conformer lines experiment --- training/conf/experiment/conformer_lines.yaml | 123 ++++++++++++++++++++++++++ 1 file changed, 123 insertions(+) create mode 100644 training/conf/experiment/conformer_lines.yaml (limited to 'training/conf') diff --git a/training/conf/experiment/conformer_lines.yaml b/training/conf/experiment/conformer_lines.yaml new file mode 100644 index 0000000..c3f4ea5 --- /dev/null +++ b/training/conf/experiment/conformer_lines.yaml @@ -0,0 +1,123 @@ +# @package _global_ + +defaults: + - override /mapping: null + - override /criterion: ctc + - override /callbacks: htr + - override /datamodule: iam_lines + - override /network: null + - override /model: null + - override /lr_schedulers: null + - override /optimizers: null + +epochs: &epochs 200 +num_classes: &num_classes 57 +max_output_len: &max_output_len 762 +summary: [[1, 57, 1024]] + +mapping: &mapping + mapping: + _target_: text_recognizer.data.mappings.EmnistMapping + +callbacks: + stochastic_weight_averaging: + _target_: pytorch_lightning.callbacks.StochasticWeightAveraging + swa_epoch_start: 0.75 + swa_lrs: 1.0e-5 + annealing_epochs: 10 + annealing_strategy: cos + device: null + +optimizers: + radam: + _target_: torch.optim.RAdam + lr: 3.0e-4 + betas: [0.9, 0.999] + weight_decay: 0 + eps: 1.0e-8 + parameters: network + +lr_schedulers: + network: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + mode: min + factor: 0.5 + patience: 10 + threshold: 1.0e-4 + threshold_mode: rel + cooldown: 0 + min_lr: 1.0e-5 + eps: 1.0e-8 + verbose: false + interval: epoch + monitor: val/loss + +datamodule: + batch_size: 8 + num_workers: 12 + train_fraction: 0.9 + pin_memory: true + << : *mapping + +network: + _target_: text_recognizer.networks.conformer.Conformer + depth: 16 + num_classes: *num_classes + dim: &dim 128 + block: + _target_: text_recognizer.networks.conformer.ConformerBlock + dim: *dim + attn: + _target_: text_recognizer.networks.conformer.Attention + dim: *dim + heads: 8 + dim_head: 64 + mult: 4 + ff: + _target_: text_recognizer.networks.conformer.Feedforward + dim: *dim + expansion_factor: 4 + dropout: 0.1 + conv: + _target_: text_recognizer.networks.conformer.ConformerConv + dim: *dim + expansion_factor: 2 + kernel_size: 31 + dropout: 0.1 + subsampler: + _target_: text_recognizer.networks.conformer.Subsampler + pixel_pos_embedding: + _target_: text_recognizer.networks.transformer.AxialPositionalEmbedding + dim: *dim + shape: [6, 127] + channels: *dim + depth: 3 + dropout: 0.1 + +model: + _target_: text_recognizer.models.conformer.LitConformer + <<: *mapping + max_output_len: *max_output_len + start_token: + end_token: + pad_token:

+ blank_token: + +trainer: + _target_: pytorch_lightning.Trainer + stochastic_weight_avg: true + auto_scale_batch_size: binsearch + auto_lr_find: false + gradient_clip_val: 0.5 + fast_dev_run: false + gpus: 1 + precision: 16 + max_epochs: *epochs + terminate_on_nan: true + weights_summary: null + limit_train_batches: 1.0 + limit_val_batches: 1.0 + limit_test_batches: 1.0 + resume_from_checkpoint: null + accumulate_grad_batches: 1 + overfit_batches: 0 -- cgit v1.2.3-70-g09d2