author    Gustaf Rydholm <gustaf.rydholm@gmail.com>  2022-10-02 02:54:35 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>  2022-10-02 02:54:35 +0200
commit    20b851f1918bcfc122edd1161a33fd496f82ee86 (patch)
tree      10abce18442a04a14f808bd8c00ab305b7dfe308 /training
parent    ffec11ce67d8fe75ea0d5dde5ddf17eb1017fa7d (diff)
Update lines experiment conf
Diffstat (limited to 'training')
-rw-r--r--  training/conf/experiment/conv_transformer_lines.yaml | 25 ++++++++++++-------------
1 file changed, 12 insertions(+), 13 deletions(-)
diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml
index 3392cd6..3f5da86 100644
--- a/training/conf/experiment/conv_transformer_lines.yaml
+++ b/training/conf/experiment/conv_transformer_lines.yaml
@@ -10,11 +10,8 @@ defaults:
   - override /optimizer: null

 tags: [lines]
-epochs: &epochs 260
+epochs: &epochs 64
 ignore_index: &ignore_index 3
-num_classes: &num_classes 58
-max_output_len: &max_output_len 89
-dim: &dim 384
 # summary: [[1, 1, 56, 1024], [1, 89]]

 logger:
@@ -57,6 +54,9 @@ lr_scheduler:
 datamodule:
   batch_size: 16
   train_fraction: 0.95
+  transform:
+    _target_: text_recognizer.data.stems.line.IamLinesStem
+    augment: false

 network:
   _target_: text_recognizer.networks.ConvTransformer
@@ -65,14 +65,14 @@ network:
   encoder:
     _target_: text_recognizer.networks.convnext.ConvNext
     dim: 16
-    dim_mults: [2, 4, 24]
+    dim_mults: [2, 4, 32]
     depths: [3, 3, 6]
     downsampling_factors: [[2, 2], [2, 2], [2, 2]]
     attn:
       _target_: text_recognizer.networks.convnext.TransformerBlock
       attn:
         _target_: text_recognizer.networks.convnext.Attention
-        dim: *dim
+        dim: &dim 512
         heads: 4
         dim_head: 64
         scale: 8
@@ -85,11 +85,10 @@ network:
       AxialPositionalEmbeddingImage"
     dim: *dim
     axial_shape: [7, 128]
-    axial_dims: [192, 192]
   decoder:
     _target_: text_recognizer.networks.text_decoder.TextDecoder
-    hidden_dim: *dim
-    num_classes: *num_classes
+    dim: *dim
+    num_classes: 58
     pad_index: *ignore_index
     decoder:
       _target_: text_recognizer.networks.transformer.Decoder
@@ -102,14 +101,14 @@ network:
         _target_: text_recognizer.networks.transformer.Attention
         dim: *dim
         num_heads: 8
-        dim_head: 64
+        dim_head: &dim_head 64
         dropout_rate: &dropout_rate 0.2
         causal: true
       cross_attn:
         _target_: text_recognizer.networks.transformer.Attention
         dim: *dim
         num_heads: 8
-        dim_head: 64
+        dim_head: *dim_head
         dropout_rate: *dropout_rate
         causal: false
       norm:
@@ -124,10 +123,10 @@ network:
         dropout_rate: *dropout_rate
       rotary_embedding:
         _target_: text_recognizer.networks.transformer.RotaryEmbedding
-        dim: 64
+        dim: *dim_head

 model:
-  max_output_len: *max_output_len
+  max_output_len: 89

 trainer:
   gradient_clip_val: 1.0
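
The structural change in this commit is anchor placement: rather than declaring dim, num_classes, and max_output_len as top-level anchors, the new config inlines each value at its first point of use (&dim 512 inside the encoder attention, &dim_head 64 inside the decoder self-attention) and dereferences it later with *dim and *dim_head. YAML requires an anchor to appear before any alias that refers to it, which this layout satisfies because the encoder block precedes the pixel embedding and decoder blocks. A minimal sketch of the pattern, with illustrative keys that are not the repository's schema:

# Anchor/alias sketch (illustrative keys, not the actual config).
encoder:
  dim: &dim 512          # anchor declared inline at first use
pixel_embedding:
  dim: *dim              # alias, resolves to 512
decoder:
  self_attn:
    dim_head: &dim_head 64
  rotary_embedding:
    dim: *dim_head       # alias, resolves to 64

Anchors and aliases are expanded by the YAML parser at load time, before Hydra composes the config, so downstream code sees plain values. The added datamodule.transform block uses the same Hydra _target_ mechanism as the rest of the file to instantiate text_recognizer.data.stems.line.IamLinesStem with augmentation disabled.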