summaryrefslogtreecommitdiff
path: root/training/conf/experiment/conv_transformer_lines.yaml
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-27 00:11:29 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-27 00:11:29 +0200
commita4546512c9a8dec632c94e506e1603d460ff0635 (patch)
tree1112daf28150ffe629878687424dd2211d3c2b85 /training/conf/experiment/conv_transformer_lines.yaml
parent91d0d49b9b0750f4b592d5fff6e440bb28f484dd (diff)
Update ex configs
Diffstat (limited to 'training/conf/experiment/conv_transformer_lines.yaml')
-rw-r--r--training/conf/experiment/conv_transformer_lines.yaml29
1 files changed, 16 insertions, 13 deletions
diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml
index 2631e81..e0e426c 100644
--- a/training/conf/experiment/conv_transformer_lines.yaml
+++ b/training/conf/experiment/conv_transformer_lines.yaml
@@ -5,7 +5,6 @@ defaults:
- override /callbacks: htr
- override /datamodule: iam_lines
- override /network: null
- # - override /network: conv_transformer
- override /model: lit_transformer
- override /lr_scheduler: null
- override /optimizer: null
@@ -23,7 +22,7 @@ logger:
criterion:
ignore_index: *ignore_index
- # label_smoothing: 0.1
+ # label_smoothing: 0.05
callbacks:
stochastic_weight_averaging:
@@ -36,7 +35,7 @@ callbacks:
optimizer:
_target_: adan_pytorch.Adan
- lr: 3.0e-4
+ lr: 1.0e-3
betas: [0.02, 0.08, 0.01]
weight_decay: 0.02
@@ -55,42 +54,43 @@ lr_scheduler:
monitor: val/cer
datamodule:
- batch_size: 8
+ batch_size: 16
train_fraction: 0.95
network:
_target_: text_recognizer.networks.ConvTransformer
input_dims: [1, 1, 56, 1024]
- hidden_dim: &hidden_dim 128
+ hidden_dim: &hidden_dim 384
num_classes: 58
pad_index: 3
encoder:
_target_: text_recognizer.networks.convnext.ConvNext
dim: 16
- dim_mults: [2, 4, 8]
+ dim_mults: [2, 4, 24]
depths: [3, 3, 6]
downsampling_factors: [[2, 2], [2, 2], [2, 2]]
attn:
_target_: text_recognizer.networks.convnext.TransformerBlock
attn:
_target_: text_recognizer.networks.convnext.Attention
- dim: 128
+ dim: *hidden_dim
heads: 4
dim_head: 64
scale: 8
ff:
_target_: text_recognizer.networks.convnext.FeedForward
- dim: 128
- mult: 4
+ dim: *hidden_dim
+ mult: 2
decoder:
_target_: text_recognizer.networks.transformer.Decoder
depth: 6
+ dim: *hidden_dim
block:
- _target_: text_recognizer.networks.transformer.DecoderBlock
+ _target_: text_recognizer.networks.transformer.decoder_block.DecoderBlock
self_attn:
_target_: text_recognizer.networks.transformer.Attention
dim: *hidden_dim
- num_heads: 12
+ num_heads: 8
dim_head: 64
dropout_rate: &dropout_rate 0.2
causal: true
@@ -100,7 +100,7 @@ network:
cross_attn:
_target_: text_recognizer.networks.transformer.Attention
dim: *hidden_dim
- num_heads: 12
+ num_heads: 8
dim_head: 64
dropout_rate: *dropout_rate
causal: false
@@ -119,7 +119,7 @@ network:
AxialPositionalEmbeddingImage"
dim: *hidden_dim
axial_shape: [7, 128]
- axial_dims: [64, 64]
+ axial_dims: [192, 192]
token_pos_embedding:
_target_: "text_recognizer.networks.transformer.embeddings.fourier.\
PositionalEncoding"
@@ -134,3 +134,6 @@ trainer:
gradient_clip_val: 1.0
max_epochs: *epochs
accumulate_grad_batches: 1
+ limit_train_batches: 1.0
+ limit_val_batches: 1.0
+ limit_test_batches: 1.0