summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/conf/experiment/vq_transformer_lines.yaml81
-rw-r--r--training/conf/model/lit_vq_transformer.yaml5
-rw-r--r--training/conf/network/vq_transformer.yaml65
3 files changed, 151 insertions, 0 deletions
diff --git a/training/conf/experiment/vq_transformer_lines.yaml b/training/conf/experiment/vq_transformer_lines.yaml
new file mode 100644
index 0000000..dbd8a3b
--- /dev/null
+++ b/training/conf/experiment/vq_transformer_lines.yaml
@@ -0,0 +1,81 @@
+# @package _global_
+
+defaults:
+ - override /criterion: cross_entropy
+ - override /callbacks: htr
+ - override /datamodule: iam_lines
+ - override /network: vq_transformer
+ - override /model: lit_vq_transformer
+ - override /lr_scheduler: null
+ - override /optimizer: null
+
+tags: [lines]
+epochs: &epochs 200
+ignore_index: &ignore_index 3
+num_classes: &num_classes 57
+max_output_len: &max_output_len 89
+summary: [[1, 1, 56, 1024], [1, 89]]
+
+logger:
+ wandb:
+ tags: ${tags}
+ # id: 342qvr1p
+
+criterion:
+ ignore_index: *ignore_index
+ label_smoothing: 0.05
+
+callbacks:
+ stochastic_weight_averaging:
+ _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+ swa_epoch_start: 0.75
+ swa_lrs: 1.0e-5
+ annealing_epochs: 10
+ annealing_strategy: cos
+ device: null
+
+optimizer:
+ _target_: torch.optim.RAdam
+ lr: 3.0e-4
+ betas: [0.9, 0.999]
+ weight_decay: 0
+ eps: 1.0e-8
+
+lr_scheduler:
+ _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+ mode: min
+ factor: 0.8
+ patience: 10
+ threshold: 1.0e-4
+ threshold_mode: rel
+ cooldown: 0
+ min_lr: 1.0e-5
+ eps: 1.0e-8
+ verbose: false
+ interval: epoch
+ monitor: val/cer
+
+datamodule:
+ batch_size: 8
+ train_fraction: 0.95
+
+network:
+ input_dims: [1, 1, 56, 1024]
+ num_classes: *num_classes
+ pad_index: *ignore_index
+ encoder:
+ depth: 5
+ decoder:
+ depth: 6
+ pixel_embedding:
+ shape: [1, 127]
+
+model:
+ max_output_len: *max_output_len
+ vq_loss_weight: 0.1
+
+trainer:
+ gradient_clip_val: 1.0
+ max_epochs: *epochs
+ accumulate_grad_batches: 1
+ # resume_from_checkpoint: /home/aktersnurra/projects/text-recognizer/training/logs/runs/2022-06-27/00-37-40/checkpoints/last.ckpt
diff --git a/training/conf/model/lit_vq_transformer.yaml b/training/conf/model/lit_vq_transformer.yaml
new file mode 100644
index 0000000..4173151
--- /dev/null
+++ b/training/conf/model/lit_vq_transformer.yaml
@@ -0,0 +1,5 @@
+_target_: text_recognizer.models.LitVqTransformer
+max_output_len: 682
+start_token: <s>
+end_token: <e>
+pad_token: <p>
diff --git a/training/conf/network/vq_transformer.yaml b/training/conf/network/vq_transformer.yaml
new file mode 100644
index 0000000..d62a4b7
--- /dev/null
+++ b/training/conf/network/vq_transformer.yaml
@@ -0,0 +1,65 @@
+_target_: text_recognizer.networks.VqTransformer
+input_dims: [1, 1, 576, 640]
+hidden_dim: &hidden_dim 144
+num_classes: 58
+pad_index: 3
+encoder:
+ _target_: text_recognizer.networks.EfficientNet
+ arch: b0
+ stochastic_dropout_rate: 0.2
+ bn_momentum: 0.99
+ bn_eps: 1.0e-3
+ depth: 5
+ out_channels: *hidden_dim
+decoder:
+ _target_: text_recognizer.networks.transformer.Decoder
+ depth: 6
+ block:
+ _target_: text_recognizer.networks.transformer.DecoderBlock
+ self_attn:
+ _target_: text_recognizer.networks.transformer.Attention
+ dim: *hidden_dim
+ num_heads: 8
+ dim_head: 64
+ dropout_rate: &dropout_rate 0.4
+ causal: true
+ rotary_embedding:
+ _target_: text_recognizer.networks.transformer.RotaryEmbedding
+ dim: 64
+ cross_attn:
+ _target_: text_recognizer.networks.transformer.Attention
+ dim: *hidden_dim
+ num_heads: 8
+ dim_head: 64
+ dropout_rate: *dropout_rate
+ causal: false
+ norm:
+ _target_: text_recognizer.networks.transformer.RMSNorm
+ dim: *hidden_dim
+ ff:
+ _target_: text_recognizer.networks.transformer.FeedForward
+ dim: *hidden_dim
+ dim_out: null
+ expansion_factor: 2
+ glu: true
+ dropout_rate: *dropout_rate
+pixel_embedding:
+ _target_: text_recognizer.networks.transformer.AxialPositionalEmbedding
+ dim: *hidden_dim
+ shape: [18, 79]
+quantizer:
+ _target_: text_recognizer.networks.quantizer.VectorQuantizer
+ input_dim: *hidden_dim
+ codebook:
+ _target_: text_recognizer.networks.quantizer.CosineSimilarityCodebook
+ dim: 16
+ codebook_size: 64
+ kmeans_init: true
+ kmeans_iters: 10
+ decay: 0.8
+ eps: 1.0e-5
+ threshold_dead: 2
+ temperature: 0.0
+ commitment: 0.25
+ ort_reg_weight: 10
+ ort_reg_max_codes: 64