summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/conf/experiment/conv_transformer_lines.yaml19
1 files changed, 10 insertions, 9 deletions
diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml
index d32c7d6..4e921f2 100644
--- a/training/conf/experiment/conv_transformer_lines.yaml
+++ b/training/conf/experiment/conv_transformer_lines.yaml
@@ -14,6 +14,7 @@ epochs: &epochs 260
ignore_index: &ignore_index 3
num_classes: &num_classes 58
max_output_len: &max_output_len 89
+dim: &dim 384
# summary: [[1, 1, 56, 1024], [1, 89]]
logger:
@@ -71,13 +72,13 @@ network:
_target_: text_recognizer.networks.convnext.TransformerBlock
attn:
_target_: text_recognizer.networks.convnext.Attention
- dim: *hidden_dim
+ dim: *dim
heads: 4
dim_head: 64
scale: 8
ff:
_target_: text_recognizer.networks.convnext.FeedForward
- dim: *hidden_dim
+ dim: *dim
mult: 2
pixel_embedding:
_target_: "text_recognizer.networks.transformer.embeddings.axial.\
@@ -87,36 +88,36 @@ network:
axial_dims: [192, 192]
decoder:
_target_: text_recognizer.networks.text_decoder.TextDecoder
- hidden_dim: *hidden_dim
+ hidden_dim: *dim
num_classes: *num_classes
pad_index: *ignore_index
decoder:
_target_: text_recognizer.networks.transformer.Decoder
- dim: *hidden_dim
+ dim: *dim
depth: 6
block:
_target_: text_recognizer.networks.transformer.decoder_block.\
DecoderBlock
self_attn:
_target_: text_recognizer.networks.transformer.Attention
- dim: *hidden_dim
+ dim: *dim
num_heads: 10
dim_head: 64
dropout_rate: &dropout_rate 0.2
causal: true
cross_attn:
_target_: text_recognizer.networks.transformer.Attention
- dim: *hidden_dim
+ dim: *dim
num_heads: 10
dim_head: 64
dropout_rate: *dropout_rate
causal: false
norm:
_target_: text_recognizer.networks.transformer.RMSNorm
- dim: *hidden_dim
+ dim: *dim
ff:
_target_: text_recognizer.networks.transformer.FeedForward
- dim: *hidden_dim
+ dim: *dim
dim_out: null
expansion_factor: 2
glu: true
@@ -127,7 +128,7 @@ network:
token_pos_embedding:
_target_: "text_recognizer.networks.transformer.embeddings.fourier.\
PositionalEncoding"
- dim: *hidden_dim
+ dim: *dim
dropout_rate: 0.1
max_len: *max_output_len