Diffstat (limited to 'training')
-rw-r--r--  training/conf/callbacks/default.yaml                    |  2
-rw-r--r--  training/conf/callbacks/lightning/checkpoint.yaml       |  8
-rw-r--r--  training/conf/experiment/conv_transformer_lines.yaml    | 39
-rw-r--r--  training/conf/lr_schedulers/one_cycle.yaml              | 37
-rw-r--r--  training/conf/model/lit_transformer.yaml                |  2
-rw-r--r--  training/conf/network/conv_transformer.yaml             |  2
-rw-r--r--  training/conf/network/decoder/transformer_decoder.yaml  | 30
-rw-r--r--  training/conf/network/encoder/efficientnet.yaml         |  5
-rw-r--r--  training/conf/trainer/default.yaml                      |  2
9 files changed, 49 insertions(+), 78 deletions(-)
diff --git a/training/conf/callbacks/default.yaml b/training/conf/callbacks/default.yaml
index 57c10a6..4d8e399 100644
--- a/training/conf/callbacks/default.yaml
+++ b/training/conf/callbacks/default.yaml
@@ -2,5 +2,5 @@ defaults:
   - lightning/checkpoint
   - lightning/learning_rate_monitor
   - wandb/watch
-  - wandb/config 
+  - wandb/config
   - wandb/checkpoints
diff --git a/training/conf/callbacks/lightning/checkpoint.yaml b/training/conf/callbacks/lightning/checkpoint.yaml
index b4101d8..9acd64f 100644
--- a/training/conf/callbacks/lightning/checkpoint.yaml
+++ b/training/conf/callbacks/lightning/checkpoint.yaml
@@ -1,9 +1,9 @@
 model_checkpoint:
   _target_: pytorch_lightning.callbacks.ModelCheckpoint
-  monitor: val/loss # name of the logged metric which determines when model is improving
-  save_top_k: 1 # save k best models (determined by above metric)
-  save_last: true # additionaly always save model from last epoch
-  mode: min # can be "max" or "min"
+  monitor: val/cer
+  save_top_k: 1
+  save_last: true
+  mode: min
   verbose: false
   dirpath: checkpoints/
   filename: "{epoch:02d}"
diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml
index 8404cd1..38b13a5 100644
--- a/training/conf/experiment/conv_transformer_lines.yaml
+++ b/training/conf/experiment/conv_transformer_lines.yaml
@@ -18,7 +18,7 @@ summary: [[1, 1, 56, 1024], [1, 89]]
 
 criterion:
   ignore_index: *ignore_index
-  # label_smoothing: 0.1
+  label_smoothing: 0.05
 
 callbacks:
   stochastic_weight_averaging:
@@ -40,30 +40,38 @@ optimizers:
 
 lr_schedulers:
   network:
-    _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
-    mode: min
-    factor: 0.5
-    patience: 10
-    threshold: 1.0e-4
-    threshold_mode: rel
-    cooldown: 0
-    min_lr: 1.0e-5
-    eps: 1.0e-8
+    _target_: torch.optim.lr_scheduler.OneCycleLR
+    max_lr: 3.0e-4
+    total_steps: null
+    epochs: *epochs
+    steps_per_epoch: 1284
+    pct_start: 0.3
+    anneal_strategy: cos
+    cycle_momentum: true
+    base_momentum: 0.85
+    max_momentum: 0.95
+    div_factor: 25.0
+    final_div_factor: 10000.0
+    three_phase: true
+    last_epoch: -1
     verbose: false
-    interval: epoch
-    monitor: val/loss
+    interval: step
+    monitor: val/cer
 
 datamodule:
-  batch_size: 16
+  batch_size: 8
+  train_fraction: 0.9
 
 network:
   input_dims: [1, 1, 56, 1024]
   num_classes: *num_classes
   pad_index: *ignore_index
+  encoder:
+    depth: 5
   decoder:
-    depth: 10
+    depth: 6
     pixel_embedding:
-      shape: [7, 128]
+      shape: [3, 64]
 
 model:
   max_output_len: *max_output_len
@@ -71,3 +79,4 @@ model:
 trainer:
   gradient_clip_val: 0.5
   max_epochs: *epochs
+  accumulate_grad_batches: 1
diff --git a/training/conf/lr_schedulers/one_cycle.yaml b/training/conf/lr_schedulers/one_cycle.yaml
index 801a01f..20eab9f 100644
--- a/training/conf/lr_schedulers/one_cycle.yaml
+++ b/training/conf/lr_schedulers/one_cycle.yaml
@@ -1,20 +1,17 @@
-one_cycle:
-  _target_: torch.optim.lr_scheduler.OneCycleLR
-  max_lr: 1.0e-3
-  total_steps: null
-  epochs: 512
-  steps_per_epoch: 4992
-  pct_start: 0.3
-  anneal_strategy: cos
-  cycle_momentum: true
-  base_momentum: 0.85
-  max_momentum: 0.95
-  div_factor: 25.0
-  final_div_factor: 10000.0
-  three_phase: true
-  last_epoch: -1
-  verbose: false
-
-  # Non-class arguments
-  interval: step
-  monitor: val/loss
+_target_: torch.optim.lr_scheduler.OneCycleLR
+max_lr: 1.0e-3
+total_steps: null
+epochs: 512
+steps_per_epoch: 4992
+pct_start: 0.3
+anneal_strategy: cos
+cycle_momentum: true
+base_momentum: 0.85
+max_momentum: 0.95
+div_factor: 25.0
+final_div_factor: 10000.0
+three_phase: true
+last_epoch: -1
+verbose: false
+interval: step
+monitor: val/loss
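Note on the de-nested one_cycle.yaml above: it mixes OneCycleLR constructor arguments with two Lightning-only keys, interval and monitor (the removed "# Non-class arguments" comment used to mark them). A minimal sketch of how such a config can be consumed inside a LightningModule's configure_optimizers, assuming hydra.utils.instantiate; the helper name and the fallback defaults are illustrative, not this repository's actual code:

from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf


def build_scheduler_dict(config: DictConfig, optimizer):
    """Split Lightning-only keys from the OneCycleLR kwargs, then instantiate.

    `interval` and `monitor` are not torch.optim.lr_scheduler.OneCycleLR
    arguments; leaving them in the config would make instantiate() raise
    a TypeError.
    """
    cfg = OmegaConf.to_container(config, resolve=True)
    interval = cfg.pop("interval", "step")
    monitor = cfg.pop("monitor", "val/loss")
    scheduler = instantiate(cfg, optimizer=optimizer)
    # The dict format PyTorch Lightning expects from configure_optimizers.
    return {"scheduler": scheduler, "interval": interval, "monitor": monitor}

With interval: step, Lightning calls scheduler.step() after every optimizer step, which matches how OneCycleLR is designed to be stepped.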
diff --git a/training/conf/model/lit_transformer.yaml b/training/conf/model/lit_transformer.yaml
index c1491ec..b795078 100644
--- a/training/conf/model/lit_transformer.yaml
+++ b/training/conf/model/lit_transformer.yaml
@@ -5,4 +5,4 @@ end_token: <e>
 pad_token: <p>
 mapping:
   _target_: text_recognizer.data.mappings.EmnistMapping
-  extra_symbols: ["\n"]
+  # extra_symbols: ["\n"]
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml
index 54eb028..39c5c46 100644
--- a/training/conf/network/conv_transformer.yaml
+++ b/training/conf/network/conv_transformer.yaml
@@ -10,7 +10,7 @@ encoder:
   bn_momentum: 0.99
   bn_eps: 1.0e-3
   depth: 3
-  out_channels: 128
+  out_channels: *hidden_dim
 decoder:
   _target_: text_recognizer.networks.transformer.Decoder
   depth: 6
diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml
deleted file mode 100644
index 4588ee9..0000000
--- a/training/conf/network/decoder/transformer_decoder.yaml
+++ /dev/null
@@ -1,30 +0,0 @@
-_target_: text_recognizer.networks.transformer.decoder.Decoder
-depth: 4
-block:
-  _target_: text_recognizer.networks.transformer.decoder.DecoderBlock
-  self_attn:
-    _target_: text_recognizer.networks.transformer.attention.Attention
-    dim: 64
-    num_heads: 4
-    dim_head: 64
-    dropout_rate: 0.05
-    causal: true
-    rotary_embedding:
-      _target_: text_recognizer.networks.transformer.embeddings.rotary.RotaryEmbedding
-      dim: 128
-  cross_attn:
-    _target_: text_recognizer.networks.transformer.attention.Attention
-    dim: 64
-    num_heads: 4
-    dim_head: 64
-    dropout_rate: 0.05
-    causal: false
-  norm:
-    _target_: text_recognizer.networks.transformer.norm.RMSNorm
-    normalized_shape: 192
-  ff:
-    _target_: text_recognizer.networks.transformer.mlp.FeedForward
-    dim_out: null
-    expansion_factor: 4
-    glu: true
-    dropout_rate: 0.2
diff --git a/training/conf/network/encoder/efficientnet.yaml b/training/conf/network/encoder/efficientnet.yaml
deleted file mode 100644
index a7be069..0000000
--- a/training/conf/network/encoder/efficientnet.yaml
+++ /dev/null
@@ -1,5 +0,0 @@
-_target_: text_recognizer.networks.efficientnet.EfficientNet
-arch: b0
-stochastic_dropout_rate: 0.2
-bn_momentum: 0.99
-bn_eps: 1.0e-3
diff --git a/training/conf/trainer/default.yaml b/training/conf/trainer/default.yaml
index d4ffcdc..c2d0d62 100644
--- a/training/conf/trainer/default.yaml
+++ b/training/conf/trainer/default.yaml
@@ -13,5 +13,5 @@ limit_train_batches: 1.0
 limit_val_batches: 1.0
 limit_test_batches: 1.0
 resume_from_checkpoint: null
-accumulate_grad_batches: 2
+accumulate_grad_batches: 1
 overfit_batches: 0
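The experiment config above pins steps_per_epoch: 1284 next to datamodule.batch_size: 8 and accumulate_grad_batches: 1 (the trainer default also drops from 2 to 1, so the effective batch size falls from 16 x 2 = 32 to 8 x 1 = 8). Because the scheduler runs with interval: step, steps_per_epoch has to match the number of optimizer steps per epoch. A back-of-the-envelope consistency check; the training-set size is an assumption chosen to reproduce the configured value, not a number from the repository:

import math

num_train_samples = 10_272       # assumed; implied by batch_size 8 and 1284 steps
batch_size = 8                   # datamodule.batch_size
accumulate_grad_batches = 1      # trainer.accumulate_grad_batches

# One optimizer step per accumulated group of batches.
steps_per_epoch = math.ceil(num_train_samples / (batch_size * accumulate_grad_batches))
print(steps_per_epoch)           # 1284, the value given to OneCycleLR

With total_steps: null, OneCycleLR derives total_steps as epochs * steps_per_epoch, so a mismatch here would end the learning-rate cycle before or after training actually stops.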