summaryrefslogtreecommitdiff
path: root/training/conf/network
diff options
context:
space:
mode:
Diffstat (limited to 'training/conf/network')
-rw-r--r--  training/conf/network/vq_transformer.yaml  |  65
1 file changed, 65 insertions, 0 deletions
diff --git a/training/conf/network/vq_transformer.yaml b/training/conf/network/vq_transformer.yaml
new file mode 100644
index 0000000..d62a4b7
--- /dev/null
+++ b/training/conf/network/vq_transformer.yaml
@@ -0,0 +1,65 @@
+_target_: text_recognizer.networks.VqTransformer
+input_dims: [1, 1, 576, 640]
+hidden_dim: &hidden_dim 144
+num_classes: 58
+pad_index: 3
+encoder:
+ _target_: text_recognizer.networks.EfficientNet
+ arch: b0
+ stochastic_dropout_rate: 0.2
+ bn_momentum: 0.99
+ bn_eps: 1.0e-3
+ depth: 5
+ out_channels: *hidden_dim
+decoder:
+ _target_: text_recognizer.networks.transformer.Decoder
+ depth: 6
+ block:
+ _target_: text_recognizer.networks.transformer.DecoderBlock
+ self_attn:
+ _target_: text_recognizer.networks.transformer.Attention
+ dim: *hidden_dim
+ num_heads: 8
+ dim_head: 64
+ dropout_rate: &dropout_rate 0.4
+ causal: true
+ rotary_embedding:
+ _target_: text_recognizer.networks.transformer.RotaryEmbedding
+ dim: 64
+ cross_attn:
+ _target_: text_recognizer.networks.transformer.Attention
+ dim: *hidden_dim
+ num_heads: 8
+ dim_head: 64
+ dropout_rate: *dropout_rate
+ causal: false
+ norm:
+ _target_: text_recognizer.networks.transformer.RMSNorm
+ dim: *hidden_dim
+ ff:
+ _target_: text_recognizer.networks.transformer.FeedForward
+ dim: *hidden_dim
+ dim_out: null
+ expansion_factor: 2
+ glu: true
+ dropout_rate: *dropout_rate
+pixel_embedding:
+ _target_: text_recognizer.networks.transformer.AxialPositionalEmbedding
+ dim: *hidden_dim
+ shape: [18, 79]
+quantizer:
+ _target_: text_recognizer.networks.quantizer.VectorQuantizer
+ input_dim: *hidden_dim
+ codebook:
+ _target_: text_recognizer.networks.quantizer.CosineSimilarityCodebook
+ dim: 16
+ codebook_size: 64
+ kmeans_init: true
+ kmeans_iters: 10
+ decay: 0.8
+ eps: 1.0e-5
+ threshold_dead: 2
+ temperature: 0.0
+ commitment: 0.25
+ ort_reg_weight: 10
+ ort_reg_max_codes: 64