summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--training/conf/experiment/conv_transformer_lines.yaml29
-rw-r--r--training/conf/experiment/conv_transformer_paragraphs.yaml16
2 files changed, 25 insertions, 20 deletions
diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml
index 2631e81..e0e426c 100644
--- a/training/conf/experiment/conv_transformer_lines.yaml
+++ b/training/conf/experiment/conv_transformer_lines.yaml
@@ -5,7 +5,6 @@ defaults:
- override /callbacks: htr
- override /datamodule: iam_lines
- override /network: null
- # - override /network: conv_transformer
- override /model: lit_transformer
- override /lr_scheduler: null
- override /optimizer: null
@@ -23,7 +22,7 @@ logger:
criterion:
ignore_index: *ignore_index
- # label_smoothing: 0.1
+ # label_smoothing: 0.05
callbacks:
stochastic_weight_averaging:
@@ -36,7 +35,7 @@ callbacks:
optimizer:
_target_: adan_pytorch.Adan
- lr: 3.0e-4
+ lr: 1.0e-3
betas: [0.02, 0.08, 0.01]
weight_decay: 0.02
@@ -55,42 +54,43 @@ lr_scheduler:
monitor: val/cer
datamodule:
- batch_size: 8
+ batch_size: 16
train_fraction: 0.95
network:
_target_: text_recognizer.networks.ConvTransformer
input_dims: [1, 1, 56, 1024]
- hidden_dim: &hidden_dim 128
+ hidden_dim: &hidden_dim 384
num_classes: 58
pad_index: 3
encoder:
_target_: text_recognizer.networks.convnext.ConvNext
dim: 16
- dim_mults: [2, 4, 8]
+ dim_mults: [2, 4, 24]
depths: [3, 3, 6]
downsampling_factors: [[2, 2], [2, 2], [2, 2]]
attn:
_target_: text_recognizer.networks.convnext.TransformerBlock
attn:
_target_: text_recognizer.networks.convnext.Attention
- dim: 128
+ dim: *hidden_dim
heads: 4
dim_head: 64
scale: 8
ff:
_target_: text_recognizer.networks.convnext.FeedForward
- dim: 128
- mult: 4
+ dim: *hidden_dim
+ mult: 2
decoder:
_target_: text_recognizer.networks.transformer.Decoder
depth: 6
+ dim: *hidden_dim
block:
- _target_: text_recognizer.networks.transformer.DecoderBlock
+ _target_: text_recognizer.networks.transformer.decoder_block.DecoderBlock
self_attn:
_target_: text_recognizer.networks.transformer.Attention
dim: *hidden_dim
- num_heads: 12
+ num_heads: 8
dim_head: 64
dropout_rate: &dropout_rate 0.2
causal: true
@@ -100,7 +100,7 @@ network:
cross_attn:
_target_: text_recognizer.networks.transformer.Attention
dim: *hidden_dim
- num_heads: 12
+ num_heads: 8
dim_head: 64
dropout_rate: *dropout_rate
causal: false
@@ -119,7 +119,7 @@ network:
AxialPositionalEmbeddingImage"
dim: *hidden_dim
axial_shape: [7, 128]
- axial_dims: [64, 64]
+ axial_dims: [192, 192]
token_pos_embedding:
_target_: "text_recognizer.networks.transformer.embeddings.fourier.\
PositionalEncoding"
@@ -134,3 +134,6 @@ trainer:
gradient_clip_val: 1.0
max_epochs: *epochs
accumulate_grad_batches: 1
+ limit_train_batches: 1.0
+ limit_val_batches: 1.0
+ limit_test_batches: 1.0
diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml
index 60898da..60ff1bf 100644
--- a/training/conf/experiment/conv_transformer_paragraphs.yaml
+++ b/training/conf/experiment/conv_transformer_paragraphs.yaml
@@ -13,13 +13,12 @@ tags: [paragraphs]
epochs: &epochs 600
num_classes: &num_classes 58
ignore_index: &ignore_index 3
-max_output_len: &max_output_len 682
+# max_output_len: &max_output_len 682
# summary: [[1, 1, 576, 640], [1, 682]]
logger:
wandb:
tags: ${tags}
- id: 8je5lxmx
criterion:
ignore_index: *ignore_index
@@ -67,9 +66,9 @@ network:
encoder:
_target_: text_recognizer.networks.convnext.ConvNext
dim: 16
- dim_mults: [2, 4, 8, 8]
- depths: [3, 3, 6, 6]
- downsampling_factors: [[2, 2], [2, 2], [2, 2], [2, 1]]
+ dim_mults: [1, 2, 4, 8, 8]
+ depths: [3, 3, 3, 3, 6]
+ downsampling_factors: [[2, 2], [2, 2], [2, 1], [2, 1], [2, 1]]
attn:
_target_: text_recognizer.networks.convnext.TransformerBlock
attn:
@@ -118,7 +117,7 @@ network:
_target_: "text_recognizer.networks.transformer.embeddings.axial.\
AxialPositionalEmbeddingImage"
dim: *hidden_dim
- axial_shape: [36, 80]
+ axial_shape: [18, 160]
axial_dims: [64, 64]
token_pos_embedding:
_target_: "text_recognizer.networks.transformer.embeddings.fourier.\
@@ -130,4 +129,7 @@ network:
trainer:
gradient_clip_val: 1.0
max_epochs: *epochs
- accumulate_grad_batches: 6
+ accumulate_grad_batches: 8
+ limit_train_batches: 1.0
+ limit_val_batches: 1.0
+ limit_test_batches: 1.0