author     Gustaf Rydholm <gustaf.rydholm@gmail.com>  2022-06-08 08:41:40 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>  2022-06-08 08:41:40 +0200
commit     80de5c023cb7740022a9b13b5f09b71d9e85f8e7 (patch)
tree       6e32c4b4a7423a96f9d7051450a88b4c8e13858b /training/conf/experiment
parent     7b660c13ce3c0edeace1107838e62c559bc6f078 (diff)
Add conformer lines experiment
Diffstat (limited to 'training/conf/experiment')
-rw-r--r--  training/conf/experiment/conformer_lines.yaml | 123
1 file changed, 123 insertions, 0 deletions
diff --git a/training/conf/experiment/conformer_lines.yaml b/training/conf/experiment/conformer_lines.yaml
new file mode 100644
index 0000000..c3f4ea5
--- /dev/null
+++ b/training/conf/experiment/conformer_lines.yaml
@@ -0,0 +1,123 @@
+# @package _global_
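+# Conformer experiment on IAM handwritten lines, trained with CTC loss.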
+
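+# Hydra first applies these config-group overrides; groups set to null are
+# defined inline further down in this file.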
+defaults:
+  - override /mapping: null
+  - override /criterion: ctc
+  - override /callbacks: htr
+  - override /datamodule: iam_lines
+  - override /network: null
+  - override /model: null
+  - override /lr_schedulers: null
+  - override /optimizers: null
+
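+# Values anchored with &name here are reused below via *name aliases.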
+epochs: &epochs 200
+num_classes: &num_classes 57
+max_output_len: &max_output_len 762
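+# Assumed to be the input dims used for the model-summary printout.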
+summary: [[1, 57, 1024]]
+
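+# Shared EMNIST character mapping, merged into the datamodule and model
+# below via the *mapping alias.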
+mapping: &mapping
+  mapping:
+    _target_: text_recognizer.data.mappings.EmnistMapping
+
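+# Stochastic weight averaging: starts at 75% of training, annealing the LR
+# to 1e-5 over 10 epochs with a cosine schedule.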
+callbacks:
+  stochastic_weight_averaging:
+    _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+    swa_epoch_start: 0.75
+    swa_lrs: 1.0e-5
+    annealing_epochs: 10
+    annealing_strategy: cos
+    device: null
+
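+# RAdam optimizer; `parameters: network` is a repo-specific key, assumed to
+# select which module's parameters this optimizer receives.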
+optimizers:
+  radam:
+    _target_: torch.optim.RAdam
+    lr: 3.0e-4
+    betas: [0.9, 0.999]
+    weight_decay: 0
+    eps: 1.0e-8
+    parameters: network
+
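+# Halve the LR when val/loss has not improved by 1e-4 (relative) for 10
+# epochs, checked once per epoch, with a floor of 1e-5.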
+lr_schedulers:
+  network:
+    _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+    mode: min
+    factor: 0.5
+    patience: 10
+    threshold: 1.0e-4
+    threshold_mode: rel
+    cooldown: 0
+    min_lr: 1.0e-5
+    eps: 1.0e-8
+    verbose: false
+    interval: epoch
+    monitor: val/loss
+
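+# IAM lines datamodule: 90/10 train/val split, with the shared mapping
+# merged in via the << merge key.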
+datamodule:
+  batch_size: 8
+  num_workers: 12
+  train_fraction: 0.9
+  pin_memory: true
+  <<: *mapping
+
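+# Conformer encoder: a convolutional subsampler with axial positional
+# embeddings feeding 16 identical blocks (attention + feed-forward +
+# convolution module), projecting to 57 output classes for CTC.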
+network:
+  _target_: text_recognizer.networks.conformer.Conformer
+  depth: 16
+  num_classes: *num_classes
+  dim: &dim 128
+  block:
+    _target_: text_recognizer.networks.conformer.ConformerBlock
+    dim: *dim
+    attn:
+      _target_: text_recognizer.networks.conformer.Attention
+      dim: *dim
+      heads: 8
+      dim_head: 64
+      mult: 4
+    ff:
+      _target_: text_recognizer.networks.conformer.Feedforward
+      dim: *dim
+      expansion_factor: 4
+      dropout: 0.1
+    conv:
+      _target_: text_recognizer.networks.conformer.ConformerConv
+      dim: *dim
+      expansion_factor: 2
+      kernel_size: 31
+      dropout: 0.1
+  subsampler:
+    _target_: text_recognizer.networks.conformer.Subsampler
+    pixel_pos_embedding:
+      _target_: text_recognizer.networks.transformer.AxialPositionalEmbedding
+      dim: *dim
+      shape: [6, 127]
+    channels: *dim
+    depth: 3
+    dropout: 0.1
+
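+# Lightning module wrapping the network; the special tokens configure the
+# mapping (the blank token is required by the CTC loss).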
+model:
+  _target_: text_recognizer.models.conformer.LitConformer
+  <<: *mapping
+  max_output_len: *max_output_len
+  start_token: <s>
+  end_token: <e>
+  pad_token: <p>
+  blank_token: <b>
+
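+# Single-GPU run with 16-bit mixed precision and gradient clipping at 0.5.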
+trainer:
+  _target_: pytorch_lightning.Trainer
+  stochastic_weight_avg: true
+  auto_scale_batch_size: binsearch
+  auto_lr_find: false
+  gradient_clip_val: 0.5
+  fast_dev_run: false
+  gpus: 1
+  precision: 16
+  max_epochs: *epochs
+  terminate_on_nan: true
+  weights_summary: null
+  limit_train_batches: 1.0
+  limit_val_batches: 1.0
+  limit_test_batches: 1.0
+  resume_from_checkpoint: null
+  accumulate_grad_batches: 1
+  overfit_batches: 0
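
For context, here is a minimal sketch of how an experiment config like this is
typically consumed by a Hydra entry point. The file name, the `network=network`
keyword, and the exact LitConformer signature are assumptions for illustration,
not taken from this commit:

    """Hypothetical training entry point (not part of this commit)."""
    import hydra
    from omegaconf import DictConfig

    @hydra.main(config_path="conf", config_name="config")
    def main(cfg: DictConfig) -> None:
        # instantiate() builds each object from its _target_ entry, e.g. the
        # Conformer network and the LitConformer model defined in the YAML above.
        network = hydra.utils.instantiate(cfg.network)
        model = hydra.utils.instantiate(cfg.model, network=network)
        datamodule = hydra.utils.instantiate(cfg.datamodule)
        trainer = hydra.utils.instantiate(cfg.trainer)
        trainer.fit(model, datamodule=datamodule)

    if __name__ == "__main__":
        main()

Selecting this experiment from the command line would then look something like
`python main.py +experiment=conformer_lines`, following the usual Hydra
experiment pattern implied by the `# @package _global_` header.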