summaryrefslogtreecommitdiff
path: root/training/conf/network
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-30 00:35:34 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-30 00:35:34 +0200
commitbace9540792d517c1687d91f455f9503d9854609 (patch)
tree38aff5e65a9f7410930aa8a6a92bf7ecd7c89bd5 /training/conf/network
parent9e85e0883f2e921ca9a57cb2fd93ec47a2535d59 (diff)
Update conv transformer model
Diffstat (limited to 'training/conf/network')
-rw-r--r--training/conf/network/conv_transformer.yaml109
1 files changed, 56 insertions, 53 deletions
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml
index 0ef862f..016adbb 100644
--- a/training/conf/network/conv_transformer.yaml
+++ b/training/conf/network/conv_transformer.yaml
@@ -1,56 +1,59 @@
_target_: text_recognizer.networks.ConvTransformer
-input_dims: [1, 1, 576, 640]
-hidden_dim: &hidden_dim 128
-num_classes: 58
-pad_index: 3
encoder:
- _target_: text_recognizer.networks.convnext.ConvNext
- dim: 16
- dim_mults: [2, 4, 8]
- depths: [3, 3, 6]
- downsampling_factors: [[2, 2], [2, 2], [2, 2]]
+ _target_: text_recognizer.networks.image_encoder.ImageEncoder
+ encoder:
+ _target_: text_recognizer.networks.convnext.ConvNext
+ dim: 16
+ dim_mults: [2, 4, 8]
+ depths: [3, 3, 6]
+ downsampling_factors: [[2, 2], [2, 2], [2, 2]]
+ pixel_embedding:
+ _target_: "text_recognizer.networks.transformer.embeddings.axial.\
+ AxialPositionalEmbeddingImage"
+ dim: &hidden_dim 128
+ axial_shape: [7, 128]
+ axial_dims: [64, 64]
decoder:
- _target_: text_recognizer.networks.transformer.Decoder
- dim: *hidden_dim
- depth: 10
- block:
- _target_: text_recognizer.networks.transformer.decoder_block.DecoderBlock
- self_attn:
- _target_: text_recognizer.networks.transformer.Attention
- dim: *hidden_dim
- num_heads: 12
- dim_head: 64
- dropout_rate: &dropout_rate 0.2
- causal: true
- rotary_embedding:
- _target_: text_recognizer.networks.transformer.RotaryEmbedding
- dim: 64
- cross_attn:
- _target_: text_recognizer.networks.transformer.Attention
- dim: *hidden_dim
- num_heads: 12
- dim_head: 64
- dropout_rate: *dropout_rate
- causal: false
- norm:
- _target_: text_recognizer.networks.transformer.RMSNorm
- dim: *hidden_dim
- ff:
- _target_: text_recognizer.networks.transformer.FeedForward
- dim: *hidden_dim
- dim_out: null
- expansion_factor: 2
- glu: true
- dropout_rate: *dropout_rate
-pixel_embedding:
- _target_: "text_recognizer.networks.transformer.embeddings.axial.\
- AxialPositionalEmbeddingImage"
- dim: *hidden_dim
- axial_shape: [7, 128]
- axial_dims: [64, 64]
-token_pos_embedding:
- _target_: "text_recognizer.networks.transformer.embeddings.fourier.\
- PositionalEncoding"
- dim: *hidden_dim
- dropout_rate: 0.1
- max_len: 89
+ _target_: text_recognizer.networks.text_decoder.TextDecoder
+ hidden_dim: *hidden_dim
+ num_classes: 58
+ pad_index: 3
+ decoder:
+ _target_: text_recognizer.networks.transformer.Decoder
+ dim: *hidden_dim
+ depth: 10
+ block:
+ _target_: text_recognizer.networks.transformer.decoder_block.DecoderBlock
+ self_attn:
+ _target_: text_recognizer.networks.transformer.Attention
+ dim: *hidden_dim
+ num_heads: 12
+ dim_head: 64
+ dropout_rate: &dropout_rate 0.2
+ causal: true
+ cross_attn:
+ _target_: text_recognizer.networks.transformer.Attention
+ dim: *hidden_dim
+ num_heads: 12
+ dim_head: 64
+ dropout_rate: *dropout_rate
+ causal: false
+ norm:
+ _target_: text_recognizer.networks.transformer.RMSNorm
+ dim: *hidden_dim
+ ff:
+ _target_: text_recognizer.networks.transformer.FeedForward
+ dim: *hidden_dim
+ dim_out: null
+ expansion_factor: 2
+ glu: true
+ dropout_rate: *dropout_rate
+ rotary_embedding:
+ _target_: text_recognizer.networks.transformer.RotaryEmbedding
+ dim: 64
+ token_pos_embedding:
+ _target_: "text_recognizer.networks.transformer.embeddings.fourier.\
+ PositionalEncoding"
+ dim: *hidden_dim
+ dropout_rate: 0.1
+ max_len: 89