summaryrefslogtreecommitdiff
path: root/training/conf/experiment
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-30 01:12:13 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-30 01:12:13 +0200
commitc614c472707910658b86bb28b9f02062e6982999 (patch)
treebd043a8196f9ee3e5339ec7be17116c0ba0cc1ef /training/conf/experiment
parent03029695897fff72c9e7a66a3f986877ebb0b0ff (diff)
Make rotary pos encoding mandatory
Diffstat (limited to 'training/conf/experiment')
-rw-r--r--training/conf/experiment/conv_transformer_lines.yaml16
1 files changed, 5 insertions, 11 deletions
diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml
index 4e921f2..3392cd6 100644
--- a/training/conf/experiment/conv_transformer_lines.yaml
+++ b/training/conf/experiment/conv_transformer_lines.yaml
@@ -83,7 +83,7 @@ network:
pixel_embedding:
_target_: "text_recognizer.networks.transformer.embeddings.axial.\
AxialPositionalEmbeddingImage"
- dim: &hidden_dim 384
+ dim: *dim
axial_shape: [7, 128]
axial_dims: [192, 192]
decoder:
@@ -96,19 +96,19 @@ network:
dim: *dim
depth: 6
block:
- _target_: text_recognizer.networks.transformer.decoder_block.\
- DecoderBlock
+ _target_: "text_recognizer.networks.transformer.decoder_block.\
+ DecoderBlock"
self_attn:
_target_: text_recognizer.networks.transformer.Attention
dim: *dim
- num_heads: 10
+ num_heads: 8
dim_head: 64
dropout_rate: &dropout_rate 0.2
causal: true
cross_attn:
_target_: text_recognizer.networks.transformer.Attention
dim: *dim
- num_heads: 10
+ num_heads: 8
dim_head: 64
dropout_rate: *dropout_rate
causal: false
@@ -125,12 +125,6 @@ network:
rotary_embedding:
_target_: text_recognizer.networks.transformer.RotaryEmbedding
dim: 64
- token_pos_embedding:
- _target_: "text_recognizer.networks.transformer.embeddings.fourier.\
- PositionalEncoding"
- dim: *dim
- dropout_rate: 0.1
- max_len: *max_output_len
model:
max_output_len: *max_output_len