summary | refs | log | tree | commit | diff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r-- training/conf/callbacks/default.yaml 2
-rw-r--r-- training/conf/callbacks/lightning/checkpoint.yaml 8
-rw-r--r-- training/conf/experiment/conv_transformer_lines.yaml 39
-rw-r--r-- training/conf/lr_schedulers/one_cycle.yaml 37
-rw-r--r-- training/conf/model/lit_transformer.yaml 2
-rw-r--r-- training/conf/network/conv_transformer.yaml 2
-rw-r--r-- training/conf/network/decoder/transformer_decoder.yaml 30
-rw-r--r-- training/conf/network/encoder/efficientnet.yaml 5
-rw-r--r-- training/conf/trainer/default.yaml 2
9 files changed, 49 insertions, 78 deletions
diff --git a/training/conf/callbacks/default.yaml b/training/conf/callbacks/default.yaml
index 57c10a6..4d8e399 100644
--- a/training/conf/callbacks/default.yaml
+++ b/training/conf/callbacks/default.yaml
@@ -2,5 +2,5 @@ defaults:
- lightning/checkpoint
- lightning/learning_rate_monitor
- wandb/watch
- - wandb/config
+ - wandb/config
- wandb/checkpoints
diff --git a/training/conf/callbacks/lightning/checkpoint.yaml b/training/conf/callbacks/lightning/checkpoint.yaml
index b4101d8..9acd64f 100644
--- a/training/conf/callbacks/lightning/checkpoint.yaml
+++ b/training/conf/callbacks/lightning/checkpoint.yaml
@@ -1,9 +1,9 @@
model_checkpoint:
_target_: pytorch_lightning.callbacks.ModelCheckpoint
- monitor: val/loss # name of the logged metric which determines when model is improving
- save_top_k: 1 # save k best models (determined by above metric)
- save_last: true # additionaly always save model from last epoch
- mode: min # can be "max" or "min"
+ monitor: val/cer
+ save_top_k: 1
+ save_last: true
+ mode: min
verbose: false
dirpath: checkpoints/
filename: "{epoch:02d}"
diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml
index 8404cd1..38b13a5 100644
--- a/training/conf/experiment/conv_transformer_lines.yaml
+++ b/training/conf/experiment/conv_transformer_lines.yaml
@@ -18,7 +18,7 @@ summary: [[1, 1, 56, 1024], [1, 89]]
criterion:
ignore_index: *ignore_index
- # label_smoothing: 0.1
+ label_smoothing: 0.05
callbacks:
stochastic_weight_averaging:
@@ -40,30 +40,38 @@ optimizers:
lr_schedulers:
network:
- _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
- mode: min
- factor: 0.5
- patience: 10
- threshold: 1.0e-4
- threshold_mode: rel
- cooldown: 0
- min_lr: 1.0e-5
- eps: 1.0e-8
+ _target_: torch.optim.lr_scheduler.OneCycleLR
+ max_lr: 3.0e-4
+ total_steps: null
+ epochs: *epochs
+ steps_per_epoch: 1284
+ pct_start: 0.3
+ anneal_strategy: cos
+ cycle_momentum: true
+ base_momentum: 0.85
+ max_momentum: 0.95
+ div_factor: 25.0
+ final_div_factor: 10000.0
+ three_phase: true
+ last_epoch: -1
verbose: false
- interval: epoch
- monitor: val/loss
+ interval: step
+ monitor: val/cer
datamodule:
- batch_size: 16
+ batch_size: 8
+ train_fraction: 0.9
network:
input_dims: [1, 1, 56, 1024]
num_classes: *num_classes
pad_index: *ignore_index
+ encoder:
+ depth: 5
decoder:
- depth: 10
+ depth: 6
pixel_embedding:
- shape: [7, 128]
+ shape: [3, 64]
model:
max_output_len: *max_output_len
@@ -71,3 +79,4 @@ model:
trainer:
gradient_clip_val: 0.5
max_epochs: *epochs
+ accumulate_grad_batches: 1
diff --git a/training/conf/lr_schedulers/one_cycle.yaml b/training/conf/lr_schedulers/one_cycle.yaml
index 801a01f..20eab9f 100644
--- a/training/conf/lr_schedulers/one_cycle.yaml
+++ b/training/conf/lr_schedulers/one_cycle.yaml
@@ -1,20 +1,17 @@
-one_cycle:
- _target_: torch.optim.lr_scheduler.OneCycleLR
- max_lr: 1.0e-3
- total_steps: null
- epochs: 512
- steps_per_epoch: 4992
- pct_start: 0.3
- anneal_strategy: cos
- cycle_momentum: true
- base_momentum: 0.85
- max_momentum: 0.95
- div_factor: 25.0
- final_div_factor: 10000.0
- three_phase: true
- last_epoch: -1
- verbose: false
-
- # Non-class arguments
- interval: step
- monitor: val/loss
+_target_: torch.optim.lr_scheduler.OneCycleLR
+max_lr: 1.0e-3
+total_steps: null
+epochs: 512
+steps_per_epoch: 4992
+pct_start: 0.3
+anneal_strategy: cos
+cycle_momentum: true
+base_momentum: 0.85
+max_momentum: 0.95
+div_factor: 25.0
+final_div_factor: 10000.0
+three_phase: true
+last_epoch: -1
+verbose: false
+interval: step
+monitor: val/loss
diff --git a/training/conf/model/lit_transformer.yaml b/training/conf/model/lit_transformer.yaml
index c1491ec..b795078 100644
--- a/training/conf/model/lit_transformer.yaml
+++ b/training/conf/model/lit_transformer.yaml
@@ -5,4 +5,4 @@ end_token: <e>
pad_token: <p>
mapping:
_target_: text_recognizer.data.mappings.EmnistMapping
- extra_symbols: ["\n"]
+ # extra_symbols: ["\n"]
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml
index 54eb028..39c5c46 100644
--- a/training/conf/network/conv_transformer.yaml
+++ b/training/conf/network/conv_transformer.yaml
@@ -10,7 +10,7 @@ encoder:
bn_momentum: 0.99
bn_eps: 1.0e-3
depth: 3
- out_channels: 128
+ out_channels: *hidden_dim
decoder:
_target_: text_recognizer.networks.transformer.Decoder
depth: 6
diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml
deleted file mode 100644
index 4588ee9..0000000
--- a/training/conf/network/decoder/transformer_decoder.yaml
+++ /dev/null
@@ -1,30 +0,0 @@
-_target_: text_recognizer.networks.transformer.decoder.Decoder
-depth: 4
-block:
- _target_: text_recognizer.networks.transformer.decoder.DecoderBlock
- self_attn:
- _target_: text_recognizer.networks.transformer.attention.Attention
- dim: 64
- num_heads: 4
- dim_head: 64
- dropout_rate: 0.05
- causal: true
- rotary_embedding:
- _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding
- dim: 128
- cross_attn:
- _target_: text_recognizer.networks.transformer.attention.Attention
- dim: 64
- num_heads: 4
- dim_head: 64
- dropout_rate: 0.05
- causal: false
- norm:
- _target_: text_recognizer.networks.transformer.norm.RMSNorm
- normalized_shape: 192
- ff:
- _target_: text_recognizer.networks.transformer.mlp.FeedForward
- dim_out: null
- expansion_factor: 4
- glu: true
- dropout_rate: 0.2
diff --git a/training/conf/network/encoder/efficientnet.yaml b/training/conf/network/encoder/efficientnet.yaml
deleted file mode 100644
index a7be069..0000000
--- a/training/conf/network/encoder/efficientnet.yaml
+++ /dev/null
@@ -1,5 +0,0 @@
-_target_: text_recognizer.networks.efficientnet.EfficientNet
-arch: b0
-stochastic_dropout_rate: 0.2
-bn_momentum: 0.99
-bn_eps: 1.0e-3
diff --git a/training/conf/trainer/default.yaml b/training/conf/trainer/default.yaml
index d4ffcdc..c2d0d62 100644
--- a/training/conf/trainer/default.yaml
+++ b/training/conf/trainer/default.yaml
@@ -13,5 +13,5 @@ limit_train_batches: 1.0
limit_val_batches: 1.0
limit_test_batches: 1.0
resume_from_checkpoint: null
-accumulate_grad_batches: 2
+accumulate_grad_batches: 1
overfit_batches: 0