summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--training/conf/experiment/conv_transformer_lines.yaml101
-rw-r--r--training/conf/experiment/conv_transformer_paragraphs.yaml7
-rw-r--r--training/conf/network/conv_transformer.yaml30
-rw-r--r--training/conf/network/convnext.yaml5
-rw-r--r--training/conf/trainer/default.yaml3
5 files changed, 98 insertions, 48 deletions
diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml
index eb9bc9e..d4478cc 100644
--- a/training/conf/experiment/conv_transformer_lines.yaml
+++ b/training/conf/experiment/conv_transformer_lines.yaml
@@ -4,20 +4,26 @@ defaults:
- override /criterion: cross_entropy
- override /callbacks: htr
- override /datamodule: iam_lines
- - override /network: conv_transformer
+ - override /network: null
+ # - override /network: conv_transformer
- override /model: lit_transformer
- override /lr_scheduler: null
- override /optimizer: null
-epochs: &epochs 512
+tags: [lines]
+epochs: &epochs 260
ignore_index: &ignore_index 3
num_classes: &num_classes 57
max_output_len: &max_output_len 89
summary: [[1, 1, 56, 1024], [1, 89]]
+logger:
+ wandb:
+ tags: ${tags}
+
criterion:
ignore_index: *ignore_index
- label_smoothing: 0.05
+ # label_smoothing: 0.1
callbacks:
stochastic_weight_averaging:
@@ -29,30 +35,23 @@ callbacks:
device: null
optimizer:
- _target_: torch.optim.RAdam
+ _target_: adan_pytorch.Adan
lr: 3.0e-4
- betas: [0.9, 0.999]
- weight_decay: 0
- eps: 1.0e-8
- parameters: network
+ betas: [0.02, 0.08, 0.01]
+ weight_decay: 0.02
lr_scheduler:
- _target_: torch.optim.lr_scheduler.OneCycleLR
- max_lr: 3.0e-4
- total_steps: null
- epochs: *epochs
- steps_per_epoch: 1354
- pct_start: 0.3
- anneal_strategy: cos
- cycle_momentum: true
- base_momentum: 0.85
- max_momentum: 0.95
- div_factor: 25.0
- final_div_factor: 10000.0
- three_phase: true
- last_epoch: -1
+ _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+ mode: min
+ factor: 0.8
+ patience: 10
+ threshold: 1.0e-4
+ threshold_mode: rel
+ cooldown: 0
+ min_lr: 1.0e-5
+ eps: 1.0e-8
verbose: false
- interval: step
+ interval: epoch
monitor: val/cer
datamodule:
@@ -60,20 +59,66 @@ datamodule:
train_fraction: 0.95
network:
+ _target_: text_recognizer.networks.ConvTransformer
input_dims: [1, 1, 56, 1024]
- num_classes: *num_classes
- pad_index: *ignore_index
+ hidden_dim: &hidden_dim 128
+ num_classes: 58
+ pad_index: 3
encoder:
- depth: 5
+ _target_: text_recognizer.networks.convnext.ConvNext
+ dim: 16
+ dim_mults: [2, 4, 8]
+ depths: [3, 3, 6]
+ downsampling_factors: [[2, 2], [2, 2], [2, 2]]
decoder:
+ _target_: text_recognizer.networks.transformer.Decoder
depth: 6
+ block:
+ _target_: text_recognizer.networks.transformer.DecoderBlock
+ self_attn:
+ _target_: text_recognizer.networks.transformer.Attention
+ dim: *hidden_dim
+ num_heads: 12
+ dim_head: 64
+ dropout_rate: &dropout_rate 0.2
+ causal: true
+ rotary_embedding:
+ _target_: text_recognizer.networks.transformer.RotaryEmbedding
+ dim: 64
+ cross_attn:
+ _target_: text_recognizer.networks.transformer.Attention
+ dim: *hidden_dim
+ num_heads: 12
+ dim_head: 64
+ dropout_rate: *dropout_rate
+ causal: false
+ norm:
+ _target_: text_recognizer.networks.transformer.RMSNorm
+ dim: *hidden_dim
+ ff:
+ _target_: text_recognizer.networks.transformer.FeedForward
+ dim: *hidden_dim
+ dim_out: null
+ expansion_factor: 2
+ glu: true
+ dropout_rate: *dropout_rate
pixel_embedding:
- shape: [3, 64]
+ _target_: "text_recognizer.networks.transformer.embeddings.axial.\
+ AxialPositionalEmbeddingImage"
+ dim: *hidden_dim
+ axial_shape: [7, 128]
+ axial_dims: [64, 64]
+ token_pos_embedding:
+ _target_: "text_recognizer.networks.transformer.embeddings.fourier.\
+ PositionalEncoding"
+ dim: *hidden_dim
+ dropout_rate: 0.1
+ max_len: 89
model:
max_output_len: *max_output_len
trainer:
- gradient_clip_val: 0.5
+ gradient_clip_val: 1.0
max_epochs: *epochs
accumulate_grad_batches: 1
diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml
index 41c236d..4bd3b45 100644
--- a/training/conf/experiment/conv_transformer_paragraphs.yaml
+++ b/training/conf/experiment/conv_transformer_paragraphs.yaml
@@ -19,10 +19,11 @@ summary: [[1, 1, 576, 640], [1, 682]]
logger:
wandb:
tags: ${tags}
+ id: 8je5lxmx
criterion:
ignore_index: *ignore_index
- label_smoothing: 0.05
+ # label_smoothing: 0.05
callbacks:
stochastic_weight_averaging:
@@ -62,10 +63,6 @@ network:
input_dims: [1, 1, 576, 640]
num_classes: *num_classes
pad_index: *ignore_index
- encoder:
- depth: 4
- decoder:
- depth: 6
pixel_embedding:
shape: [18, 79]
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml
index f618ba1..c71296b 100644
--- a/training/conf/network/conv_transformer.yaml
+++ b/training/conf/network/conv_transformer.yaml
@@ -1,19 +1,17 @@
_target_: text_recognizer.networks.ConvTransformer
input_dims: [1, 1, 576, 640]
-hidden_dim: &hidden_dim 144
+hidden_dim: &hidden_dim 128
num_classes: 58
pad_index: 3
encoder:
- _target_: text_recognizer.networks.EfficientNet
- arch: b0
- stochastic_dropout_rate: 0.2
- bn_momentum: 0.99
- bn_eps: 1.0e-3
- depth: 5
- out_channels: *hidden_dim
+ _target_: text_recognizer.networks.convnext.ConvNext
+ dim: 16
+ dim_mults: [2, 4, 8]
+ depths: [3, 3, 6]
+ downsampling_factors: [[2, 2], [2, 2], [2, 2]]
decoder:
_target_: text_recognizer.networks.transformer.Decoder
- depth: 6
+ depth: 10
block:
_target_: text_recognizer.networks.transformer.DecoderBlock
self_attn:
@@ -29,7 +27,7 @@ decoder:
cross_attn:
_target_: text_recognizer.networks.transformer.Attention
dim: *hidden_dim
- num_heads: 8
+ num_heads: 12
dim_head: 64
dropout_rate: *dropout_rate
causal: false
@@ -44,6 +42,14 @@ decoder:
glu: true
dropout_rate: *dropout_rate
pixel_embedding:
- _target_: text_recognizer.networks.transformer.AxialPositionalEmbedding
+ _target_: "text_recognizer.networks.transformer.embeddings.axial.\
+ AxialPositionalEmbeddingImage"
+ dim: *hidden_dim
+ axial_shape: [7, 128]
+ axial_dims: [64, 64]
+token_pos_embedding:
+ _target_: "text_recognizer.networks.transformer.embeddings.fourier.\
+ PositionalEncoding"
dim: *hidden_dim
- shape: [18, 79]
+ dropout_rate: 0.1
+ max_len: 89
diff --git a/training/conf/network/convnext.yaml b/training/conf/network/convnext.yaml
new file mode 100644
index 0000000..bc1ce93
--- /dev/null
+++ b/training/conf/network/convnext.yaml
@@ -0,0 +1,5 @@
+_target_: text_recognizer.networks.convnext.ConvNext
+dim: 16
+dim_mults: [2, 4, 8]
+depths: [3, 3, 6]
+downsampling_factors: [[2, 2], [2, 2], [2, 2]]
diff --git a/training/conf/trainer/default.yaml b/training/conf/trainer/default.yaml
index c2d0d62..6112cd8 100644
--- a/training/conf/trainer/default.yaml
+++ b/training/conf/trainer/default.yaml
@@ -1,5 +1,4 @@
_target_: pytorch_lightning.Trainer
-stochastic_weight_avg: true
auto_scale_batch_size: binsearch
auto_lr_find: false
gradient_clip_val: 0.5
@@ -7,8 +6,6 @@ fast_dev_run: false
gpus: 1
precision: 16
max_epochs: 256
-terminate_on_nan: true
-weights_summary: null
limit_train_batches: 1.0
limit_val_batches: 1.0
limit_test_batches: 1.0