summaryrefslogtreecommitdiff
path: root/training/conf/network
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-05 00:09:19 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-05 00:09:19 +0200
commit12abf17cd7c31ae4599be366505a4423fbba4044 (patch)
tree996e5d549ebbb7d22f5acfdcd321bddec77f98d1 /training/conf/network
parent16e2e420e077253c3b2bc414283281fea557717d (diff)
Update perceiver conf
Diffstat (limited to 'training/conf/network')
-rw-r--r--training/conf/network/conv_perceiver.yaml23
-rw-r--r--training/conf/network/conv_transformer.yaml2
2 files changed, 16 insertions, 9 deletions
diff --git a/training/conf/network/conv_perceiver.yaml b/training/conf/network/conv_perceiver.yaml
index e6906fa..2e12db9 100644
--- a/training/conf/network/conv_perceiver.yaml
+++ b/training/conf/network/conv_perceiver.yaml
@@ -1,9 +1,10 @@
_target_: text_recognizer.networks.ConvPerceiver
input_dims: [1, 1, 576, 640]
-hidden_dim: &hidden_dim 144
+hidden_dim: &hidden_dim 128
num_classes: &num_classes 58
-queries_dim: &queries_dim 16
-max_length: 89
+max_length: &max_length 89
+num_queries: *max_length
+queries_dim: &queries_dim 64
pad_index: 3
encoder:
_target_: text_recognizer.networks.EfficientNet
@@ -15,16 +16,22 @@ encoder:
out_channels: *hidden_dim
decoder:
_target_: text_recognizer.networks.perceiver.PerceiverIO
- dim: *hidden_dim
+ dim: 192
cross_heads: 1
cross_head_dim: 64
num_latents: 256
latent_dim: 512
latent_heads: 8
depth: 6
- queries_dim: *queries_dim
+ queries_dim: 128
logits_dim: *num_classes
pixel_embedding:
- _target_: text_recognizer.networks.transformer.AxialPositionalEmbedding
- dim: *hidden_dim
- shape: [3, 64]
+ _target_: text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbeddingImage
+ dim: 64
+ axial_shape: [3, 64]
+ axial_dims: [32, 32]
+query_pos_emb:
+ _target_: text_recognizer.networks.transformer.embeddings.absolute.AbsolutePositionalEmbedding
+ dim: 64
+ max_seq_len: *max_length
+ l2norm_embed: true
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml
index c4cf13e..f618ba1 100644
--- a/training/conf/network/conv_transformer.yaml
+++ b/training/conf/network/conv_transformer.yaml
@@ -1,6 +1,6 @@
_target_: text_recognizer.networks.ConvTransformer
input_dims: [1, 1, 576, 640]
-hidden_dim: &hidden_dim 96
+hidden_dim: &hidden_dim 144
num_classes: 58
pad_index: 3
encoder: