summaryrefslogtreecommitdiff
path: root/training/conf/network/conv_transformer.yaml
diff options
context:
space:
mode:
author: Gustaf Rydholm <gustaf.rydholm@gmail.com> 2022-09-13 19:09:25 +0200
committer: Gustaf Rydholm <gustaf.rydholm@gmail.com> 2022-09-13 19:09:25 +0200
commit: 2862fdf77a19c4afa5e00c900af4877df31a3ea6 (patch)
tree: 5a225df3113849b5c129b8ae42497cef383b4e93 /training/conf/network/conv_transformer.yaml
parent: aaa26d7d550bb2d4ced4b87414fd314608148865 (diff)
Update configs
Diffstat (limited to 'training/conf/network/conv_transformer.yaml')
-rw-r--r-- training/conf/network/conv_transformer.yaml | 30
1 file changed, 18 insertions(+), 12 deletions(-)
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml
index f618ba1..c71296b 100644
--- a/training/conf/network/conv_transformer.yaml
+++ b/training/conf/network/conv_transformer.yaml
@@ -1,19 +1,17 @@
_target_: text_recognizer.networks.ConvTransformer
input_dims: [1, 1, 576, 640]
-hidden_dim: &hidden_dim 144
+hidden_dim: &hidden_dim 128
num_classes: 58
pad_index: 3
encoder:
- _target_: text_recognizer.networks.EfficientNet
- arch: b0
- stochastic_dropout_rate: 0.2
- bn_momentum: 0.99
- bn_eps: 1.0e-3
- depth: 5
- out_channels: *hidden_dim
+ _target_: text_recognizer.networks.convnext.ConvNext
+ dim: 16
+ dim_mults: [2, 4, 8]
+ depths: [3, 3, 6]
+ downsampling_factors: [[2, 2], [2, 2], [2, 2]]
decoder:
_target_: text_recognizer.networks.transformer.Decoder
- depth: 6
+ depth: 10
block:
_target_: text_recognizer.networks.transformer.DecoderBlock
self_attn:
@@ -29,7 +27,7 @@ decoder:
cross_attn:
_target_: text_recognizer.networks.transformer.Attention
dim: *hidden_dim
- num_heads: 8
+ num_heads: 12
dim_head: 64
dropout_rate: *dropout_rate
causal: false
@@ -44,6 +42,14 @@ decoder:
glu: true
dropout_rate: *dropout_rate
pixel_embedding:
- _target_: text_recognizer.networks.transformer.AxialPositionalEmbedding
+ _target_: "text_recognizer.networks.transformer.embeddings.axial.\
+ AxialPositionalEmbeddingImage"
+ dim: *hidden_dim
+ axial_shape: [7, 128]
+ axial_dims: [64, 64]
+token_pos_embedding:
+ _target_: "text_recognizer.networks.transformer.embeddings.fourier.\
+ PositionalEncoding"
dim: *hidden_dim
- shape: [18, 79]
+ dropout_rate: 0.1
+ max_len: 89