Diffstat (limited to 'training')
-rw-r--r-- | training/conf/experiment/vq_transformer_lines.yaml | 149
-rw-r--r-- | training/conf/network/quantizer.yaml               |  12
2 files changed, 161 insertions, 0 deletions
diff --git a/training/conf/experiment/vq_transformer_lines.yaml b/training/conf/experiment/vq_transformer_lines.yaml
new file mode 100644
index 0000000..bbe1178
--- /dev/null
+++ b/training/conf/experiment/vq_transformer_lines.yaml
@@ -0,0 +1,149 @@
+# @package _global_
+
+defaults:
+  - override /mapping: null
+  - override /criterion: cross_entropy
+  - override /callbacks: htr
+  - override /datamodule: iam_lines
+  - override /network: null
+  - override /model: null
+  - override /lr_schedulers: null
+  - override /optimizers: null
+
+epochs: &epochs 512
+ignore_index: &ignore_index 3
+num_classes: &num_classes 57
+max_output_len: &max_output_len 89
+summary: [[1, 1, 56, 1024], [1, 89]]
+
+criterion:
+  ignore_index: *ignore_index
+
+mapping: &mapping
+  mapping:
+    _target_: text_recognizer.data.mappings.emnist.EmnistMapping
+
+callbacks:
+  stochastic_weight_averaging:
+    _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+    swa_epoch_start: 0.75
+    swa_lrs: 1.0e-5
+    annealing_epochs: 10
+    annealing_strategy: cos
+    device: null
+
+optimizers:
+  madgrad:
+    _target_: madgrad.MADGRAD
+    lr: 3.0e-4
+    momentum: 0.9
+    weight_decay: 0
+    eps: 1.0e-6
+    parameters: network
+
+lr_schedulers:
+  network:
+    _target_: torch.optim.lr_scheduler.CosineAnnealingLR
+    T_max: *epochs
+    eta_min: 1.0e-5
+    last_epoch: -1
+    interval: epoch
+    monitor: val/loss
+
+datamodule:
+  batch_size: 16
+  num_workers: 12
+  train_fraction: 0.9
+  pin_memory: true
+  << : *mapping
+
+rotary_embedding: &rotary_embedding
+  rotary_embedding:
+    _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding
+    dim: 64
+
+attn: &attn
+  dim: &hidden_dim 512
+  num_heads: 4
+  dim_head: 64
+  dropout_rate: &dropout_rate 0.4
+
+network:
+  _target_: text_recognizer.networks.vq_transformer.VqTransformer
+  input_dims: [1, 56, 1024]
+  hidden_dim: *hidden_dim
+  num_classes: *num_classes
+  pad_index: *ignore_index
+  encoder:
+    _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet
+    arch: b1
+    stochastic_dropout_rate: 0.2
+    bn_momentum: 0.99
+    bn_eps: 1.0e-3
+  decoder:
+    depth: 6
+    _target_: text_recognizer.networks.transformer.layers.Decoder
+    self_attn:
+      _target_: text_recognizer.networks.transformer.attention.Attention
+      << : *attn
+      causal: true
+      << : *rotary_embedding
+    cross_attn:
+      _target_: text_recognizer.networks.transformer.attention.Attention
+      << : *attn
+      causal: false
+    norm:
+      _target_: text_recognizer.networks.transformer.norm.ScaleNorm
+      normalized_shape: *hidden_dim
+    ff:
+      _target_: text_recognizer.networks.transformer.mlp.FeedForward
+      dim: *hidden_dim
+      dim_out: null
+      expansion_factor: 4
+      glu: true
+      dropout_rate: *dropout_rate
+    pre_norm: true
+  pixel_pos_embedding:
+    _target_: text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbedding
+    dim: *hidden_dim
+    shape: [1, 32]
+  quantizer:
+    _target_: text_recognizer.networks.quantizer.quantizer.VectorQuantizer
+    input_dim: 512
+    codebook:
+      _target_: text_recognizer.networks.quantizer.codebook.CosineSimilarityCodebook
+      dim: 16
+      codebook_size: 4096
+      kmeans_init: true
+      kmeans_iters: 10
+      decay: 0.8
+      eps: 1.0e-5
+      threshold_dead: 2
+    commitment: 1.0
+
+model:
+  _target_: text_recognizer.models.vq_transformer.VqTransformerLitModel
+  << : *mapping
+  max_output_len: *max_output_len
+  start_token: <s>
+  end_token: <e>
+  pad_token: <p>
+
+trainer:
+  _target_: pytorch_lightning.Trainer
+  stochastic_weight_avg: true
+  auto_scale_batch_size: binsearch
+  auto_lr_find: false
+  gradient_clip_val: 0.5
+  fast_dev_run: false
+  gpus: 1
+  precision: 16
+  max_epochs: *epochs
+  terminate_on_nan: true
+  weights_summary: null
+  limit_train_batches: 1.0
+  limit_val_batches: 1.0
+  limit_test_batches: 1.0
+  resume_from_checkpoint: null
+  accumulate_grad_batches: 1
+  overfit_batches: 0
diff --git a/training/conf/network/quantizer.yaml b/training/conf/network/quantizer.yaml
new file mode 100644
index 0000000..827a247
--- /dev/null
+++ b/training/conf/network/quantizer.yaml
@@ -0,0 +1,12 @@
+_target_: text_recognizer.networks.quantizer.quantizer.VectorQuantizer
+input_dim: 192
+codebook:
+  _target_: text_recognizer.networks.quantizer.codebook.CosineSimilarityCodebook
+  dim: 16
+  codebook_size: 2048
+  kmeans_init: true
+  kmeans_iters: 10
+  decay: 0.8
+  eps: 1.0e-5
+  threshold_dead: 2
+commitment: 1.0
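Both files are Hydra configs: every `_target_` key names a class that Hydra constructs with the sibling keys as keyword arguments, recursing into nested `_target_` nodes. A minimal usage sketch, assuming Hydra >= 1.1 (where `instantiate` is recursive by default); the repo's actual training entry point is not part of this diff:

```python
# Sketch: resolving a _target_ config into an object with standard Hydra APIs.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("training/conf/network/quantizer.yaml")
quantizer = instantiate(cfg)  # builds VectorQuantizer; the nested
                              # CosineSimilarityCodebook node is built recursively
```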
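For orientation, the quantizer these configs describe sits between the EfficientNet encoder and the transformer decoder: features are projected down to the codebook dimension, snapped to their nearest code by cosine similarity, and trained with a commitment loss plus a straight-through gradient. Below is a minimal sketch of that core step only; the class name `CosineVQ` is hypothetical, and the EMA `decay`, `kmeans_init`, and `threshold_dead` dead-code replacement from the config are deliberately omitted, so this is not the `text_recognizer` implementation.

```python
# Minimal sketch of a cosine-similarity vector quantizer, assuming the
# semantics suggested by the config keys (input_dim, dim, codebook_size,
# commitment). EMA updates, k-means init, and dead-code handling are omitted.
import torch
import torch.nn as nn
import torch.nn.functional as F


class CosineVQ(nn.Module):  # hypothetical name, not from the repo
    def __init__(self, input_dim: int, dim: int, codebook_size: int,
                 commitment: float = 1.0) -> None:
        super().__init__()
        # Project encoder features down to the (much smaller) codebook dim.
        self.project_in = nn.Linear(input_dim, dim)
        self.project_out = nn.Linear(dim, input_dim)
        self.codebook = nn.Parameter(torch.randn(codebook_size, dim))
        self.commitment = commitment

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        z = self.project_in(x)                           # (B, N, dim)
        # Cosine similarity: compare L2-normalized features and codes.
        z_n = F.normalize(z, dim=-1)
        c_n = F.normalize(self.codebook, dim=-1)
        indices = (z_n @ c_n.t()).argmax(dim=-1)         # nearest code per vector
        quantized = F.embedding(indices, self.codebook)  # (B, N, dim)
        # Commitment loss pulls encoder outputs toward their chosen codes.
        loss = self.commitment * F.mse_loss(z, quantized.detach())
        # Straight-through estimator: copy gradients past the argmax.
        quantized = z + (quantized - z).detach()
        return self.project_out(quantized), loss
```

The two configs differ only in scale: the experiment file uses `input_dim: 512` (matching the decoder's `hidden_dim`) with a 4096-entry codebook, while the standalone `training/conf/network/quantizer.yaml` uses `input_dim: 192` with 2048 entries.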