summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-30 00:35:34 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-30 00:35:34 +0200
commitbace9540792d517c1687d91f455f9503d9854609 (patch)
tree38aff5e65a9f7410930aa8a6a92bf7ecd7c89bd5 /training
parent9e85e0883f2e921ca9a57cb2fd93ec47a2535d59 (diff)
Update conv transformer model
Diffstat (limited to 'training')
-rw-r--r--training/conf/experiment/conv_transformer_lines.yaml136
-rw-r--r--training/conf/network/conv_transformer.yaml109
2 files changed, 126 insertions, 119 deletions
diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml
index e0e426c..d32c7d6 100644
--- a/training/conf/experiment/conv_transformer_lines.yaml
+++ b/training/conf/experiment/conv_transformer_lines.yaml
@@ -12,7 +12,7 @@ defaults:
tags: [lines]
epochs: &epochs 260
ignore_index: &ignore_index 3
-num_classes: &num_classes 57
+num_classes: &num_classes 58
max_output_len: &max_output_len 89
# summary: [[1, 1, 56, 1024], [1, 89]]
@@ -35,7 +35,7 @@ callbacks:
optimizer:
_target_: adan_pytorch.Adan
- lr: 1.0e-3
+ lr: 3.0e-4
betas: [0.02, 0.08, 0.01]
weight_decay: 0.02
@@ -59,73 +59,77 @@ datamodule:
network:
_target_: text_recognizer.networks.ConvTransformer
- input_dims: [1, 1, 56, 1024]
- hidden_dim: &hidden_dim 384
- num_classes: 58
- pad_index: 3
encoder:
- _target_: text_recognizer.networks.convnext.ConvNext
- dim: 16
- dim_mults: [2, 4, 24]
- depths: [3, 3, 6]
- downsampling_factors: [[2, 2], [2, 2], [2, 2]]
- attn:
- _target_: text_recognizer.networks.convnext.TransformerBlock
+ _target_: text_recognizer.networks.image_encoder.ImageEncoder
+ encoder:
+ _target_: text_recognizer.networks.convnext.ConvNext
+ dim: 16
+ dim_mults: [2, 4, 24]
+ depths: [3, 3, 6]
+ downsampling_factors: [[2, 2], [2, 2], [2, 2]]
attn:
- _target_: text_recognizer.networks.convnext.Attention
- dim: *hidden_dim
- heads: 4
- dim_head: 64
- scale: 8
- ff:
- _target_: text_recognizer.networks.convnext.FeedForward
- dim: *hidden_dim
- mult: 2
+ _target_: text_recognizer.networks.convnext.TransformerBlock
+ attn:
+ _target_: text_recognizer.networks.convnext.Attention
+ dim: *hidden_dim
+ heads: 4
+ dim_head: 64
+ scale: 8
+ ff:
+ _target_: text_recognizer.networks.convnext.FeedForward
+ dim: *hidden_dim
+ mult: 2
+ pixel_embedding:
+ _target_: "text_recognizer.networks.transformer.embeddings.axial.\
+ AxialPositionalEmbeddingImage"
+ dim: &hidden_dim 384
+ axial_shape: [7, 128]
+ axial_dims: [192, 192]
decoder:
- _target_: text_recognizer.networks.transformer.Decoder
- depth: 6
- dim: *hidden_dim
- block:
- _target_: text_recognizer.networks.transformer.decoder_block.DecoderBlock
- self_attn:
- _target_: text_recognizer.networks.transformer.Attention
- dim: *hidden_dim
- num_heads: 8
- dim_head: 64
- dropout_rate: &dropout_rate 0.2
- causal: true
- rotary_embedding:
- _target_: text_recognizer.networks.transformer.RotaryEmbedding
- dim: 64
- cross_attn:
- _target_: text_recognizer.networks.transformer.Attention
- dim: *hidden_dim
- num_heads: 8
- dim_head: 64
- dropout_rate: *dropout_rate
- causal: false
- norm:
- _target_: text_recognizer.networks.transformer.RMSNorm
- dim: *hidden_dim
- ff:
- _target_: text_recognizer.networks.transformer.FeedForward
- dim: *hidden_dim
- dim_out: null
- expansion_factor: 2
- glu: true
- dropout_rate: *dropout_rate
- pixel_embedding:
- _target_: "text_recognizer.networks.transformer.embeddings.axial.\
- AxialPositionalEmbeddingImage"
- dim: *hidden_dim
- axial_shape: [7, 128]
- axial_dims: [192, 192]
- token_pos_embedding:
- _target_: "text_recognizer.networks.transformer.embeddings.fourier.\
- PositionalEncoding"
- dim: *hidden_dim
- dropout_rate: 0.1
- max_len: 89
+ _target_: text_recognizer.networks.text_decoder.TextDecoder
+ hidden_dim: *hidden_dim
+ num_classes: *num_classes
+ pad_index: *ignore_index
+ decoder:
+ _target_: text_recognizer.networks.transformer.Decoder
+ dim: *hidden_dim
+ depth: 6
+ block:
+ _target_: text_recognizer.networks.transformer.decoder_block.\
+ DecoderBlock
+ self_attn:
+ _target_: text_recognizer.networks.transformer.Attention
+ dim: *hidden_dim
+ num_heads: 10
+ dim_head: 64
+ dropout_rate: &dropout_rate 0.2
+ causal: true
+ cross_attn:
+ _target_: text_recognizer.networks.transformer.Attention
+ dim: *hidden_dim
+ num_heads: 10
+ dim_head: 64
+ dropout_rate: *dropout_rate
+ causal: false
+ norm:
+ _target_: text_recognizer.networks.transformer.RMSNorm
+ dim: *hidden_dim
+ ff:
+ _target_: text_recognizer.networks.transformer.FeedForward
+ dim: *hidden_dim
+ dim_out: null
+ expansion_factor: 2
+ glu: true
+ dropout_rate: *dropout_rate
+ rotary_embedding:
+ _target_: text_recognizer.networks.transformer.RotaryEmbedding
+ dim: 64
+ token_pos_embedding:
+ _target_: "text_recognizer.networks.transformer.embeddings.fourier.\
+ PositionalEncoding"
+ dim: *hidden_dim
+ dropout_rate: 0.1
+ max_len: *max_output_len
model:
max_output_len: *max_output_len
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml
index 0ef862f..016adbb 100644
--- a/training/conf/network/conv_transformer.yaml
+++ b/training/conf/network/conv_transformer.yaml
@@ -1,56 +1,59 @@
_target_: text_recognizer.networks.ConvTransformer
-input_dims: [1, 1, 576, 640]
-hidden_dim: &hidden_dim 128
-num_classes: 58
-pad_index: 3
encoder:
- _target_: text_recognizer.networks.convnext.ConvNext
- dim: 16
- dim_mults: [2, 4, 8]
- depths: [3, 3, 6]
- downsampling_factors: [[2, 2], [2, 2], [2, 2]]
+ _target_: text_recognizer.networks.image_encoder.ImageEncoder
+ encoder:
+ _target_: text_recognizer.networks.convnext.ConvNext
+ dim: 16
+ dim_mults: [2, 4, 8]
+ depths: [3, 3, 6]
+ downsampling_factors: [[2, 2], [2, 2], [2, 2]]
+ pixel_embedding:
+ _target_: "text_recognizer.networks.transformer.embeddings.axial.\
+ AxialPositionalEmbeddingImage"
+ dim: &hidden_dim 128
+ axial_shape: [7, 128]
+ axial_dims: [64, 64]
decoder:
- _target_: text_recognizer.networks.transformer.Decoder
- dim: *hidden_dim
- depth: 10
- block:
- _target_: text_recognizer.networks.transformer.decoder_block.DecoderBlock
- self_attn:
- _target_: text_recognizer.networks.transformer.Attention
- dim: *hidden_dim
- num_heads: 12
- dim_head: 64
- dropout_rate: &dropout_rate 0.2
- causal: true
- rotary_embedding:
- _target_: text_recognizer.networks.transformer.RotaryEmbedding
- dim: 64
- cross_attn:
- _target_: text_recognizer.networks.transformer.Attention
- dim: *hidden_dim
- num_heads: 12
- dim_head: 64
- dropout_rate: *dropout_rate
- causal: false
- norm:
- _target_: text_recognizer.networks.transformer.RMSNorm
- dim: *hidden_dim
- ff:
- _target_: text_recognizer.networks.transformer.FeedForward
- dim: *hidden_dim
- dim_out: null
- expansion_factor: 2
- glu: true
- dropout_rate: *dropout_rate
-pixel_embedding:
- _target_: "text_recognizer.networks.transformer.embeddings.axial.\
- AxialPositionalEmbeddingImage"
- dim: *hidden_dim
- axial_shape: [7, 128]
- axial_dims: [64, 64]
-token_pos_embedding:
- _target_: "text_recognizer.networks.transformer.embeddings.fourier.\
- PositionalEncoding"
- dim: *hidden_dim
- dropout_rate: 0.1
- max_len: 89
+ _target_: text_recognizer.networks.text_decoder.TextDecoder
+ hidden_dim: *hidden_dim
+ num_classes: 58
+ pad_index: 3
+ decoder:
+ _target_: text_recognizer.networks.transformer.Decoder
+ dim: *hidden_dim
+ depth: 10
+ block:
+ _target_: text_recognizer.networks.transformer.decoder_block.DecoderBlock
+ self_attn:
+ _target_: text_recognizer.networks.transformer.Attention
+ dim: *hidden_dim
+ num_heads: 12
+ dim_head: 64
+ dropout_rate: &dropout_rate 0.2
+ causal: true
+ cross_attn:
+ _target_: text_recognizer.networks.transformer.Attention
+ dim: *hidden_dim
+ num_heads: 12
+ dim_head: 64
+ dropout_rate: *dropout_rate
+ causal: false
+ norm:
+ _target_: text_recognizer.networks.transformer.RMSNorm
+ dim: *hidden_dim
+ ff:
+ _target_: text_recognizer.networks.transformer.FeedForward
+ dim: *hidden_dim
+ dim_out: null
+ expansion_factor: 2
+ glu: true
+ dropout_rate: *dropout_rate
+ rotary_embedding:
+ _target_: text_recognizer.networks.transformer.RotaryEmbedding
+ dim: 64
+ token_pos_embedding:
+ _target_: "text_recognizer.networks.transformer.embeddings.fourier.\
+ PositionalEncoding"
+ dim: *hidden_dim
+ dropout_rate: 0.1
+ max_len: 89