diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2023-08-25 23:19:39 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2023-08-25 23:19:39 +0200 |
commit | 6968572c1a21394b88a29f675b17b9698784a898 (patch) | |
tree | d89d1c5c2ec331d38dcb5b6a2dbbd72c9e355b8a /training/conf/network | |
parent | 49ca6ade1a19f7f9c702171537fe4be0dfcda66d (diff) |
Update training stuff
Diffstat (limited to 'training/conf/network')
-rw-r--r-- | training/conf/network/conv_transformer.yaml | 26 | ||||
-rw-r--r-- | training/conf/network/convnext.yaml | 8 | ||||
-rw-r--r-- | training/conf/network/vit_lines.yaml | 37 |
3 files changed, 54 insertions, 17 deletions
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml index 016adbb..1e03946 100644 --- a/training/conf/network/conv_transformer.yaml +++ b/training/conf/network/conv_transformer.yaml @@ -1,58 +1,58 @@ -_target_: text_recognizer.networks.ConvTransformer +_target_: text_recognizer.network.ConvTransformer encoder: - _target_: text_recognizer.networks.image_encoder.ImageEncoder + _target_: text_recognizer.network.image_encoder.ImageEncoder encoder: - _target_: text_recognizer.networks.convnext.ConvNext + _target_: text_recognizer.network.convnext.ConvNext dim: 16 dim_mults: [2, 4, 8] depths: [3, 3, 6] downsampling_factors: [[2, 2], [2, 2], [2, 2]] pixel_embedding: - _target_: "text_recognizer.networks.transformer.embeddings.axial.\ + _target_: "text_recognizer.network.transformer.embeddings.axial.\ AxialPositionalEmbeddingImage" dim: &hidden_dim 128 axial_shape: [7, 128] axial_dims: [64, 64] decoder: - _target_: text_recognizer.networks.text_decoder.TextDecoder + _target_: text_recognizer.network.text_decoder.TextDecoder hidden_dim: *hidden_dim num_classes: 58 pad_index: 3 decoder: - _target_: text_recognizer.networks.transformer.Decoder + _target_: text_recognizer.network.transformer.Decoder dim: *hidden_dim depth: 10 block: - _target_: text_recognizer.networks.transformer.decoder_block.DecoderBlock + _target_: text_recognizer.network.transformer.decoder_block.DecoderBlock self_attn: - _target_: text_recognizer.networks.transformer.Attention + _target_: text_recognizer.network.transformer.Attention dim: *hidden_dim num_heads: 12 dim_head: 64 dropout_rate: &dropout_rate 0.2 causal: true cross_attn: - _target_: text_recognizer.networks.transformer.Attention + _target_: text_recognizer.network.transformer.Attention dim: *hidden_dim num_heads: 12 dim_head: 64 dropout_rate: *dropout_rate causal: false norm: - _target_: text_recognizer.networks.transformer.RMSNorm + _target_: text_recognizer.network.transformer.RMSNorm dim: *hidden_dim ff: - _target_: text_recognizer.networks.transformer.FeedForward + _target_: text_recognizer.network.transformer.FeedForward dim: *hidden_dim dim_out: null expansion_factor: 2 glu: true dropout_rate: *dropout_rate rotary_embedding: - _target_: text_recognizer.networks.transformer.RotaryEmbedding + _target_: text_recognizer.network.transformer.RotaryEmbedding dim: 64 token_pos_embedding: - _target_: "text_recognizer.networks.transformer.embeddings.fourier.\ + _target_: "text_recognizer.network.transformer.embeddings.fourier.\ PositionalEncoding" dim: *hidden_dim dropout_rate: 0.1 diff --git a/training/conf/network/convnext.yaml b/training/conf/network/convnext.yaml index 63ad424..904bd56 100644 --- a/training/conf/network/convnext.yaml +++ b/training/conf/network/convnext.yaml @@ -1,17 +1,17 @@ -_target_: text_recognizer.networks.convnext.ConvNext +_target_: text_recognizer.network.convnext.ConvNext dim: 16 dim_mults: [2, 4, 8] depths: [3, 3, 6] downsampling_factors: [[2, 2], [2, 2], [2, 2]] attn: - _target_: text_recognizer.networks.convnext.TransformerBlock + _target_: text_recognizer.network.convnext.TransformerBlock attn: - _target_: text_recognizer.networks.convnext.Attention + _target_: text_recognizer.network.convnext.Attention dim: 128 heads: 4 dim_head: 64 scale: 8 ff: - _target_: text_recognizer.networks.convnext.FeedForward + _target_: text_recognizer.network.convnext.FeedForward dim: 128 mult: 4 diff --git a/training/conf/network/vit_lines.yaml b/training/conf/network/vit_lines.yaml new file mode 100644 index 0000000..35f83c3 --- /dev/null +++ b/training/conf/network/vit_lines.yaml @@ -0,0 +1,37 @@ +_target_: text_recognizer.network.vit.VisionTransformer +image_height: 56 +image_width: 1024 +patch_height: 28 +patch_width: 32 +dim: &dim 256 +num_classes: &num_classes 57 +encoder: + _target_: text_recognizer.network.transformer.encoder.Encoder + dim: *dim + inner_dim: 1024 + heads: 8 + dim_head: 64 + depth: 6 + dropout_rate: 0.0 +decoder: + _target_: text_recognizer.network.transformer.decoder.Decoder + dim: *dim + inner_dim: 1024 + heads: 8 + dim_head: 64 + depth: 6 + dropout_rate: 0.0 +token_embedding: + _target_: "text_recognizer.network.transformer.embedding.token.\ + TokenEmbedding" + num_tokens: *num_classes + dim: *dim + use_l2: true +pos_embedding: + _target_: "text_recognizer.network.transformer.embedding.absolute.\ + AbsolutePositionalEmbedding" + dim: *dim + max_length: 89 + use_l2: true +tie_embeddings: true +pad_index: 3 |