summaryrefslogtreecommitdiff
path: root/training/conf/experiment
diff options
context:
space:
mode:
Diffstat (limited to 'training/conf/experiment')
-rw-r--r--training/conf/experiment/conv_transformer_lines.yaml101
-rw-r--r--training/conf/experiment/conv_transformer_paragraphs.yaml7
2 files changed, 75 insertions, 33 deletions
diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml
index eb9bc9e..d4478cc 100644
--- a/training/conf/experiment/conv_transformer_lines.yaml
+++ b/training/conf/experiment/conv_transformer_lines.yaml
@@ -4,20 +4,26 @@ defaults:
- override /criterion: cross_entropy
- override /callbacks: htr
- override /datamodule: iam_lines
- - override /network: conv_transformer
+ - override /network: null
+ # - override /network: conv_transformer
- override /model: lit_transformer
- override /lr_scheduler: null
- override /optimizer: null
-epochs: &epochs 512
+tags: [lines]
+epochs: &epochs 260
ignore_index: &ignore_index 3
num_classes: &num_classes 57
max_output_len: &max_output_len 89
summary: [[1, 1, 56, 1024], [1, 89]]
+logger:
+ wandb:
+ tags: ${tags}
+
criterion:
ignore_index: *ignore_index
- label_smoothing: 0.05
+ # label_smoothing: 0.1
callbacks:
stochastic_weight_averaging:
@@ -29,30 +35,23 @@ callbacks:
device: null
optimizer:
- _target_: torch.optim.RAdam
+ _target_: adan_pytorch.Adan
lr: 3.0e-4
- betas: [0.9, 0.999]
- weight_decay: 0
- eps: 1.0e-8
- parameters: network
+ betas: [0.02, 0.08, 0.01]
+ weight_decay: 0.02
lr_scheduler:
- _target_: torch.optim.lr_scheduler.OneCycleLR
- max_lr: 3.0e-4
- total_steps: null
- epochs: *epochs
- steps_per_epoch: 1354
- pct_start: 0.3
- anneal_strategy: cos
- cycle_momentum: true
- base_momentum: 0.85
- max_momentum: 0.95
- div_factor: 25.0
- final_div_factor: 10000.0
- three_phase: true
- last_epoch: -1
+ _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+ mode: min
+ factor: 0.8
+ patience: 10
+ threshold: 1.0e-4
+ threshold_mode: rel
+ cooldown: 0
+ min_lr: 1.0e-5
+ eps: 1.0e-8
verbose: false
- interval: step
+ interval: epoch
monitor: val/cer
datamodule:
@@ -60,20 +59,66 @@ datamodule:
train_fraction: 0.95
network:
+ _target_: text_recognizer.networks.ConvTransformer
input_dims: [1, 1, 56, 1024]
- num_classes: *num_classes
- pad_index: *ignore_index
+ hidden_dim: &hidden_dim 128
+ num_classes: 58
+ pad_index: 3
encoder:
- depth: 5
+ _target_: text_recognizer.networks.convnext.ConvNext
+ dim: 16
+ dim_mults: [2, 4, 8]
+ depths: [3, 3, 6]
+ downsampling_factors: [[2, 2], [2, 2], [2, 2]]
decoder:
+ _target_: text_recognizer.networks.transformer.Decoder
depth: 6
+ block:
+ _target_: text_recognizer.networks.transformer.DecoderBlock
+ self_attn:
+ _target_: text_recognizer.networks.transformer.Attention
+ dim: *hidden_dim
+ num_heads: 12
+ dim_head: 64
+ dropout_rate: &dropout_rate 0.2
+ causal: true
+ rotary_embedding:
+ _target_: text_recognizer.networks.transformer.RotaryEmbedding
+ dim: 64
+ cross_attn:
+ _target_: text_recognizer.networks.transformer.Attention
+ dim: *hidden_dim
+ num_heads: 12
+ dim_head: 64
+ dropout_rate: *dropout_rate
+ causal: false
+ norm:
+ _target_: text_recognizer.networks.transformer.RMSNorm
+ dim: *hidden_dim
+ ff:
+ _target_: text_recognizer.networks.transformer.FeedForward
+ dim: *hidden_dim
+ dim_out: null
+ expansion_factor: 2
+ glu: true
+ dropout_rate: *dropout_rate
pixel_embedding:
- shape: [3, 64]
+ _target_: "text_recognizer.networks.transformer.embeddings.axial.\
+ AxialPositionalEmbeddingImage"
+ dim: *hidden_dim
+ axial_shape: [7, 128]
+ axial_dims: [64, 64]
+ token_pos_embedding:
+ _target_: "text_recognizer.networks.transformer.embeddings.fourier.\
+ PositionalEncoding"
+ dim: *hidden_dim
+ dropout_rate: 0.1
+ max_len: 89
model:
max_output_len: *max_output_len
trainer:
- gradient_clip_val: 0.5
+ gradient_clip_val: 1.0
max_epochs: *epochs
accumulate_grad_batches: 1
diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml
index 41c236d..4bd3b45 100644
--- a/training/conf/experiment/conv_transformer_paragraphs.yaml
+++ b/training/conf/experiment/conv_transformer_paragraphs.yaml
@@ -19,10 +19,11 @@ summary: [[1, 1, 576, 640], [1, 682]]
logger:
wandb:
tags: ${tags}
+ id: 8je5lxmx
criterion:
ignore_index: *ignore_index
- label_smoothing: 0.05
+ # label_smoothing: 0.05
callbacks:
stochastic_weight_averaging:
@@ -62,10 +63,6 @@ network:
input_dims: [1, 1, 576, 640]
num_classes: *num_classes
pad_index: *ignore_index
- encoder:
- depth: 4
- decoder:
- depth: 6
pixel_embedding:
shape: [18, 79]