summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/conf/experiment/conv_transformer_paragraphs.yaml78
-rw-r--r--training/conf/network/convnext.yaml12
2 files changed, 80 insertions, 10 deletions
diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml
index 4bd3b45..34eedab 100644
--- a/training/conf/experiment/conv_transformer_paragraphs.yaml
+++ b/training/conf/experiment/conv_transformer_paragraphs.yaml
@@ -4,7 +4,7 @@ defaults:
- override /criterion: cross_entropy
- override /callbacks: htr
- override /datamodule: iam_extended_paragraphs
- - override /network: conv_transformer
+ - override /network: null
- override /model: lit_transformer
- override /lr_scheduler: null
- override /optimizer: null
@@ -35,11 +35,10 @@ callbacks:
device: null
optimizer:
- _target_: torch.optim.RAdam
+ _target_: adan_pytorch.Adan
lr: 3.0e-4
- betas: [0.9, 0.999]
- weight_decay: 0
- eps: 1.0e-8
+ betas: [0.02, 0.08, 0.01]
+ weight_decay: 0.02
lr_scheduler:
_target_: torch.optim.lr_scheduler.ReduceLROnPlateau
@@ -60,14 +59,73 @@ datamodule:
train_fraction: 0.95
network:
+ _target_: text_recognizer.networks.ConvTransformer
input_dims: [1, 1, 576, 640]
+ hidden_dim: &hidden_dim 128
num_classes: *num_classes
- pad_index: *ignore_index
+ pad_index: 3
+ encoder:
+ _target_: text_recognizer.networks.convnext.ConvNext
+ dim: 16
+ dim_mults: [2, 4, 8, 8, 8]
+ depths: [3, 3, 3, 4, 6]
+ downsampling_factors: [[2, 2], [2, 2], [2, 2], [2, 1], [2, 1]]
+ attn:
+ _target_: text_recognizer.networks.convnext.TransformerBlock
+ attn:
+ _target_: text_recognizer.networks.convnext.Attention
+ dim: 128
+ heads: 4
+ dim_head: 64
+ scale: 8
+ ff:
+ _target_: text_recognizer.networks.convnext.FeedForward
+ dim: 128
+ mult: 4
+ decoder:
+ _target_: text_recognizer.networks.transformer.Decoder
+ depth: 6
+ block:
+ _target_: text_recognizer.networks.transformer.DecoderBlock
+ self_attn:
+ _target_: text_recognizer.networks.transformer.Attention
+ dim: *hidden_dim
+ num_heads: 12
+ dim_head: 64
+ dropout_rate: &dropout_rate 0.2
+ causal: true
+ rotary_embedding:
+ _target_: text_recognizer.networks.transformer.RotaryEmbedding
+ dim: 64
+ cross_attn:
+ _target_: text_recognizer.networks.transformer.Attention
+ dim: *hidden_dim
+ num_heads: 12
+ dim_head: 64
+ dropout_rate: *dropout_rate
+ causal: false
+ norm:
+ _target_: text_recognizer.networks.transformer.RMSNorm
+ dim: *hidden_dim
+ ff:
+ _target_: text_recognizer.networks.transformer.FeedForward
+ dim: *hidden_dim
+ dim_out: null
+ expansion_factor: 2
+ glu: true
+ dropout_rate: *dropout_rate
pixel_embedding:
- shape: [18, 79]
-
-model:
- max_output_len: *max_output_len
+ _target_: "text_recognizer.networks.transformer.embeddings.axial.\
+ AxialPositionalEmbeddingImage"
+ dim: *hidden_dim
+ axial_shape: [7, 128]
+ axial_dims: [64, 64]
+ token_pos_embedding:
+ _target_: "text_recognizer.networks.transformer.embeddings.fourier.\
+ PositionalEncoding"
+ dim: *hidden_dim
+ dropout_rate: 0.1
+ max_len: 89
trainer:
gradient_clip_val: 1.0
diff --git a/training/conf/network/convnext.yaml b/training/conf/network/convnext.yaml
index bc1ce93..63ad424 100644
--- a/training/conf/network/convnext.yaml
+++ b/training/conf/network/convnext.yaml
@@ -3,3 +3,15 @@ dim: 16
dim_mults: [2, 4, 8]
depths: [3, 3, 6]
downsampling_factors: [[2, 2], [2, 2], [2, 2]]
+attn:
+ _target_: text_recognizer.networks.convnext.TransformerBlock
+ attn:
+ _target_: text_recognizer.networks.convnext.Attention
+ dim: 128
+ heads: 4
+ dim_head: 64
+ scale: 8
+ ff:
+ _target_: text_recognizer.networks.convnext.FeedForward
+ dim: 128
+ mult: 4