summaryrefslogtreecommitdiff
path: root/training/conf/network
diff options
context:
space:
mode:
Diffstat (limited to 'training/conf/network')
-rw-r--r--training/conf/network/conv_transformer.yaml30
-rw-r--r--training/conf/network/convnext.yaml5
2 files changed, 23 insertions, 12 deletions
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml
index f618ba1..c71296b 100644
--- a/training/conf/network/conv_transformer.yaml
+++ b/training/conf/network/conv_transformer.yaml
@@ -1,19 +1,17 @@
_target_: text_recognizer.networks.ConvTransformer
input_dims: [1, 1, 576, 640]
-hidden_dim: &hidden_dim 144
+hidden_dim: &hidden_dim 128
num_classes: 58
pad_index: 3
encoder:
- _target_: text_recognizer.networks.EfficientNet
- arch: b0
- stochastic_dropout_rate: 0.2
- bn_momentum: 0.99
- bn_eps: 1.0e-3
- depth: 5
- out_channels: *hidden_dim
+ _target_: text_recognizer.networks.convnext.ConvNext
+ dim: 16
+ dim_mults: [2, 4, 8]
+ depths: [3, 3, 6]
+ downsampling_factors: [[2, 2], [2, 2], [2, 2]]
decoder:
_target_: text_recognizer.networks.transformer.Decoder
- depth: 6
+ depth: 10
block:
_target_: text_recognizer.networks.transformer.DecoderBlock
self_attn:
@@ -29,7 +27,7 @@ decoder:
cross_attn:
_target_: text_recognizer.networks.transformer.Attention
dim: *hidden_dim
- num_heads: 8
+ num_heads: 12
dim_head: 64
dropout_rate: *dropout_rate
causal: false
@@ -44,6 +42,14 @@ decoder:
glu: true
dropout_rate: *dropout_rate
pixel_embedding:
- _target_: text_recognizer.networks.transformer.AxialPositionalEmbedding
+ _target_: "text_recognizer.networks.transformer.embeddings.axial.\
+ AxialPositionalEmbeddingImage"
+ dim: *hidden_dim
+ axial_shape: [7, 128]
+ axial_dims: [64, 64]
+token_pos_embedding:
+ _target_: "text_recognizer.networks.transformer.embeddings.fourier.\
+ PositionalEncoding"
dim: *hidden_dim
- shape: [18, 79]
+ dropout_rate: 0.1
+ max_len: 89
diff --git a/training/conf/network/convnext.yaml b/training/conf/network/convnext.yaml
new file mode 100644
index 0000000..bc1ce93
--- /dev/null
+++ b/training/conf/network/convnext.yaml
@@ -0,0 +1,5 @@
+_target_: text_recognizer.networks.convnext.ConvNext
+dim: 16
+dim_mults: [2, 4, 8]
+depths: [3, 3, 6]
+downsampling_factors: [[2, 2], [2, 2], [2, 2]]