author    Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-10-27 22:16:39 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-10-27 22:16:39 +0200
commit    3e24b92ee1bac124ea8c7bddb15236ccc5fe300d (patch)
tree      81803764977345b9264b166926024559908cb066 /training/conf
parent    4a6550ddef7d1f1971737bc22715db6381441f79 (diff)
Update to configs
Diffstat (limited to 'training/conf')
-rw-r--r--  training/conf/experiment/barlow_twins_paragraphs.yaml         4
-rw-r--r--  training/conf/experiment/cnn_htr_wp_lines.yaml                2
-rw-r--r--  training/conf/experiment/conv_transformer_paragraphs.yaml     46
-rw-r--r--  training/conf/experiment/conv_transformer_paragraphs_wp.yaml  6
-rw-r--r--  training/conf/mapping/characters.yaml                         2
-rw-r--r--  training/conf/mapping/word_piece.yaml                         2
6 files changed, 33 insertions, 29 deletions
diff --git a/training/conf/experiment/barlow_twins_paragraphs.yaml b/training/conf/experiment/barlow_twins_paragraphs.yaml
index caefb47..9552c0b 100644
--- a/training/conf/experiment/barlow_twins_paragraphs.yaml
+++ b/training/conf/experiment/barlow_twins_paragraphs.yaml
@@ -37,10 +37,10 @@ optimizers:
lr_schedulers:
network:
_target_: torch.optim.lr_scheduler.OneCycleLR
- max_lr: 1.0e-1
+ max_lr: 1.0e-3
total_steps: null
epochs: *epochs
- steps_per_epoch: 5053
+ steps_per_epoch: 40
pct_start: 0.03
anneal_strategy: cos
cycle_momentum: true
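The scheduler block above maps one-to-one onto `torch.optim.lr_scheduler.OneCycleLR`. A minimal sketch of the instantiation with the new values; the model, optimizer, and epoch count are placeholders, not taken from this repo:

```python
# Hedged sketch: the lr_schedulers.network node above, spelled out in PyTorch.
# Only the scheduler kwargs mirror the config; everything else is a stand-in.
import torch

model = torch.nn.Linear(8, 8)  # placeholder network
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1.0e-3,        # lowered from 1.0e-1 in this commit
    epochs=100,           # stands in for the *epochs anchor Hydra resolves
    steps_per_epoch=40,   # changed from 5053 in this commit
    pct_start=0.03,
    anneal_strategy="cos",
    cycle_momentum=True,
)
# scheduler.step() is then called once per optimizer step,
# epochs * steps_per_epoch times in total.
```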
diff --git a/training/conf/experiment/cnn_htr_wp_lines.yaml b/training/conf/experiment/cnn_htr_wp_lines.yaml
index 6cdd023..9f1164a 100644
--- a/training/conf/experiment/cnn_htr_wp_lines.yaml
+++ b/training/conf/experiment/cnn_htr_wp_lines.yaml
@@ -117,7 +117,7 @@ network:
cross_attend: true
pre_norm: true
rotary_emb:
- _target_: text_recognizer.networks.transformer.positional_encodings.rotary_embedding.RotaryEmbedding
+ _target_: text_recognizer.networks.transformer.positional_encodings.rotary.RotaryEmbedding
dim: 32
pixel_pos_embedding:
_target_: text_recognizer.networks.transformer.positional_encodings.PositionalEncoding2D
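Renames like this matter because `_target_` strings are dotted import paths that Hydra only resolves when the experiment is instantiated; a stale path survives config loading and fails later with an import error. A minimal sketch of the mechanism, where the dict stands in for the YAML node and the `text_recognizer` package is assumed importable:

```python
# Hedged sketch of Hydra's _target_ resolution for the rotary_emb node above.
from hydra.utils import instantiate

cfg = {
    "_target_": "text_recognizer.networks.transformer.positional_encodings.rotary.RotaryEmbedding",
    "dim": 32,
}
# Imports ...positional_encodings.rotary and calls RotaryEmbedding(dim=32);
# the old ...rotary_embedding path would raise an import error here instead.
rotary_emb = instantiate(cfg)
```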
diff --git a/training/conf/experiment/conv_transformer_paragraphs.yaml b/training/conf/experiment/conv_transformer_paragraphs.yaml
index 9e8bc50..4de9722 100644
--- a/training/conf/experiment/conv_transformer_paragraphs.yaml
+++ b/training/conf/experiment/conv_transformer_paragraphs.yaml
@@ -1,4 +1,4 @@
-# @package _glbal_
+# @package _global_
defaults:
- override /mapping: null
@@ -11,7 +11,7 @@ defaults:
- override /optimizers: null
-epochs: &epochs 1000
+epochs: &epochs 720
ignore_index: &ignore_index 3
num_classes: &num_classes 58
max_output_len: &max_output_len 682
@@ -23,7 +23,7 @@ criterion:
mapping: &mapping
mapping:
- _target_: text_recognizer.data.mappings.emnist_mapping.EmnistMapping
+ _target_: text_recognizer.data.mappings.emnist.EmnistMapping
extra_symbols: [ "\n" ]
callbacks:
@@ -38,11 +38,10 @@ callbacks:
optimizers:
madgrad:
_target_: madgrad.MADGRAD
- lr: 2.0e-4
+ lr: 1.0e-4
momentum: 0.9
weight_decay: 5.0e-6
eps: 1.0e-6
-
parameters: network
lr_schedulers:
@@ -51,18 +50,17 @@ lr_schedulers:
max_lr: 1.0e-4
total_steps: null
epochs: *epochs
- steps_per_epoch: 632
+ steps_per_epoch: 316
pct_start: 0.03
anneal_strategy: cos
cycle_momentum: true
base_momentum: 0.85
max_momentum: 0.95
div_factor: 25
- final_div_factor: 1.0e4
+ final_div_factor: 1.0e2
three_phase: false
last_epoch: -1
verbose: false
- # Non-class arguments
interval: step
monitor: val/loss
@@ -79,7 +77,7 @@ network:
input_dims: [1, 576, 640]
hidden_dim: &hidden_dim 256
encoder_dim: 1280
- dropout_rate: 0.2
+ dropout_rate: 0.1
num_classes: *num_classes
pad_index: *ignore_index
encoder:
@@ -90,35 +88,41 @@ network:
bn_momentum: 0.99
bn_eps: 1.0e-3
decoder:
- _target_: text_recognizer.networks.transformer.Decoder
+ _target_: text_recognizer.networks.transformer.layers.Decoder
dim: *hidden_dim
- depth: 3
+ depth: 3
num_heads: 4
attn_fn: text_recognizer.networks.transformer.attention.Attention
attn_kwargs:
dim_head: 32
- dropout_rate: 0.2
+ dropout_rate: 0.05
+ local_attn_fn: text_recognizer.networks.transformer.local_attention.LocalAttention
+ local_attn_kwargs:
+ dim_head: 32
+ dropout_rate: 0.05
+ window_size: 11
+ look_back: 2
+ depth: 2
norm_fn: text_recognizer.networks.transformer.norm.ScaleNorm
ff_fn: text_recognizer.networks.transformer.mlp.FeedForward
ff_kwargs:
dim_out: null
expansion_factor: 4
glu: true
- dropout_rate: 0.2
+ dropout_rate: 0.05
cross_attend: true
pre_norm: true
rotary_emb:
- _target_: text_recognizer.networks.transformer.positional_encodings.rotary_embedding.RotaryEmbedding
+ _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding
dim: 32
pixel_pos_embedding:
- _target_: text_recognizer.networks.transformer.positional_encodings.PositionalEncoding2D
- hidden_dim: *hidden_dim
- max_h: 18
- max_w: 20
+ _target_: text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbedding
+ dim: *hidden_dim
+ shape: [18, 20]
token_pos_embedding:
- _target_: text_recognizer.networks.transformer.positional_encodings.PositionalEncoding
+ _target_: text_recognizer.networks.transformer.embeddings.fourier.PositionalEncoding
hidden_dim: *hidden_dim
- dropout_rate: 0.2
+ dropout_rate: 0.05
max_len: *max_output_len
model:
@@ -134,7 +138,7 @@ trainer:
stochastic_weight_avg: true
auto_scale_batch_size: binsearch
auto_lr_find: false
- gradient_clip_val: 0.75
+ gradient_clip_val: 0.5
fast_dev_run: false
gpus: 1
precision: 16
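Among these changes, the pixel position encoding swaps a fixed 2D sinusoidal table for a learned axial embedding, so `hidden_dim`/`max_h`/`max_w` become `dim`/`shape`. A minimal sketch of what an axial positional embedding with `dim: 256, shape: [18, 20]` typically computes; this is an assumption about the pattern, not the repo's actual implementation:

```python
# Hedged sketch of an axial positional embedding: one learned vector per row
# and one per column, summed, so h + w parameter rows cover an h x w grid
# instead of h * w independent positions.
import torch
import torch.nn as nn

class AxialPositionalEmbedding(nn.Module):
    def __init__(self, dim, shape):
        super().__init__()
        h, w = shape
        self.row_emb = nn.Parameter(torch.randn(h, 1, dim) * 0.02)
        self.col_emb = nn.Parameter(torch.randn(1, w, dim) * 0.02)

    def forward(self, x):
        # x: (batch, h * w, dim) -- flattened feature map from the encoder.
        pos = (self.row_emb + self.col_emb).reshape(-1, x.shape[-1])
        return x + pos[: x.shape[1]]

emb = AxialPositionalEmbedding(dim=256, shape=(18, 20))
tokens = torch.randn(2, 18 * 20, 256)
print(emb(tokens).shape)  # torch.Size([2, 360, 256])
```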
diff --git a/training/conf/experiment/conv_transformer_paragraphs_wp.yaml b/training/conf/experiment/conv_transformer_paragraphs_wp.yaml
index ebaa17a..91fba9a 100644
--- a/training/conf/experiment/conv_transformer_paragraphs_wp.yaml
+++ b/training/conf/experiment/conv_transformer_paragraphs_wp.yaml
@@ -103,14 +103,14 @@ network:
attn_fn: text_recognizer.networks.transformer.attention.Attention
attn_kwargs:
dim_head: 32
- dropout_rate: 0.2
+ dropout_rate: 0.05
norm_fn: text_recognizer.networks.transformer.norm.ScaleNorm
ff_fn: text_recognizer.networks.transformer.mlp.FeedForward
ff_kwargs:
dim_out: null
expansion_factor: 4
glu: true
- dropout_rate: 0.2
+ dropout_rate: 0.05
cross_attend: true
pre_norm: true
rotary_emb:
@@ -124,7 +124,7 @@ network:
token_pos_embedding:
_target_: text_recognizer.networks.transformer.positional_encodings.PositionalEncoding
hidden_dim: *hidden_dim
- dropout_rate: 0.2
+ dropout_rate: 0.05
max_len: *max_output_len
model:
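The feed-forward block whose dropout is reduced here is configured with `glu: true` and `expansion_factor: 4`. A minimal sketch of a gated feed-forward matching those ff_kwargs; an assumption about the module's shape, not the repo's code:

```python
# Hedged sketch of a GLU-style FeedForward: project up by expansion_factor,
# gate one half with the other (a common GELU-gated variant), project back.
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, expansion_factor=4,
                 glu=True, dropout_rate=0.05):
        super().__init__()
        inner = dim * expansion_factor
        self.glu = glu
        self.proj_in = nn.Linear(dim, inner * 2 if glu else inner)
        self.drop = nn.Dropout(dropout_rate)
        self.proj_out = nn.Linear(inner, dim_out or dim)

    def forward(self, x):
        x = self.proj_in(x)
        if self.glu:
            x, gate = x.chunk(2, dim=-1)
            x = x * F.gelu(gate)
        else:
            x = F.gelu(x)
        return self.proj_out(self.drop(x))
```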
diff --git a/training/conf/mapping/characters.yaml b/training/conf/mapping/characters.yaml
index 41a26a3..d91c9e5 100644
--- a/training/conf/mapping/characters.yaml
+++ b/training/conf/mapping/characters.yaml
@@ -1,2 +1,2 @@
-_target_: text_recognizer.data.mappings.emnist_mapping.EmnistMapping
+_target_: text_recognizer.data.mappings.emnist.EmnistMapping
extra_symbols: [ "\n" ]
diff --git a/training/conf/mapping/word_piece.yaml b/training/conf/mapping/word_piece.yaml
index c005cc4..6b4dc07 100644
--- a/training/conf/mapping/word_piece.yaml
+++ b/training/conf/mapping/word_piece.yaml
@@ -1,4 +1,4 @@
-_target_: text_recognizer.data.mappings.word_piece_mapping.WordPieceMapping
+_target_: text_recognizer.data.mappings.word_piece.WordPieceMapping
num_features: 1000
tokens: iamdb_1kwp_tokens_1000.txt
lexicon: iamdb_1kwp_lex_1000.txt
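Since most of this commit is chasing module renames through `_target_` strings, a quick way to catch a missed path is to import each one directly. A hypothetical smoke-test sketch, assuming the `text_recognizer` package is installed; the target list is illustrative, not exhaustive:

```python
# Hypothetical smoke test (not from the repo): verify that renamed _target_
# paths still resolve to importable classes before a training run.
import importlib

targets = [
    "text_recognizer.data.mappings.emnist.EmnistMapping",
    "text_recognizer.data.mappings.word_piece.WordPieceMapping",
    "text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding",
]

for path in targets:
    module_name, _, class_name = path.rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    print(f"{path} -> {cls.__name__}")
```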