diff options
Diffstat (limited to 'training/conf')
15 files changed, 35 insertions, 16 deletions
diff --git a/training/conf/callbacks/checkpoint.yaml b/training/conf/callbacks/checkpoint.yaml index db34cb1..b4101d8 100644 --- a/training/conf/callbacks/checkpoint.yaml +++ b/training/conf/callbacks/checkpoint.yaml @@ -6,4 +6,4 @@ model_checkpoint: mode: min # can be "max" or "min" verbose: false dirpath: checkpoints/ - filename: {epoch:02d} + filename: "{epoch:02d}" diff --git a/training/conf/callbacks/wandb/checkpoints.yaml b/training/conf/callbacks/wandb_checkpoints.yaml index a4a16ff..a4a16ff 100644 --- a/training/conf/callbacks/wandb/checkpoints.yaml +++ b/training/conf/callbacks/wandb_checkpoints.yaml diff --git a/training/conf/callbacks/wandb/code.yaml b/training/conf/callbacks/wandb_code.yaml index 35f6ea3..35f6ea3 100644 --- a/training/conf/callbacks/wandb/code.yaml +++ b/training/conf/callbacks/wandb_code.yaml diff --git a/training/conf/callbacks/wandb/image_reconstructions.yaml b/training/conf/callbacks/wandb_image_reconstructions.yaml index e69de29..e69de29 100644 --- a/training/conf/callbacks/wandb/image_reconstructions.yaml +++ b/training/conf/callbacks/wandb_image_reconstructions.yaml diff --git a/training/conf/callbacks/wandb_ocr.yaml b/training/conf/callbacks/wandb_ocr.yaml index efa3dda..9c9a6da 100644 --- a/training/conf/callbacks/wandb_ocr.yaml +++ b/training/conf/callbacks/wandb_ocr.yaml @@ -1,6 +1,6 @@ defaults: - default - - wandb/watch - - wandb/code - - wandb/checkpoints - - wandb/ocr_predictions + - wandb_watch + - wandb_code + - wandb_checkpoints + - wandb_ocr_predictions diff --git a/training/conf/callbacks/wandb/ocr_predictions.yaml b/training/conf/callbacks/wandb_ocr_predictions.yaml index 573fa96..573fa96 100644 --- a/training/conf/callbacks/wandb/ocr_predictions.yaml +++ b/training/conf/callbacks/wandb_ocr_predictions.yaml diff --git a/training/conf/callbacks/wandb/watch.yaml b/training/conf/callbacks/wandb_watch.yaml index 511608c..511608c 100644 --- a/training/conf/callbacks/wandb/watch.yaml +++ b/training/conf/callbacks/wandb_watch.yaml diff --git a/training/conf/config.yaml b/training/conf/config.yaml index 93215ed..782bcbb 100644 --- a/training/conf/config.yaml +++ b/training/conf/config.yaml @@ -1,8 +1,9 @@ defaults: - callbacks: wandb_ocr - criterion: label_smoothing - - dataset: iam_extended_paragraphs + - datamodule: iam_extended_paragraphs - hydra: default + - logger: wandb - lr_scheduler: one_cycle - mapping: word_piece - model: lit_transformer @@ -15,3 +16,21 @@ tune: false train: true test: true logging: INFO + +# path to original working directory +# hydra hijacks working directory by changing it to the current log directory, +# so it's useful to have this path as a special variable +# learn more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory +work_dir: ${hydra:runtime.cwd} + +# use `python run.py debug=true` for easy debugging! +# this will run 1 train, val and test loop with only 1 batch +# equivalent to running `python run.py trainer.fast_dev_run=true` +# (this is placed here just for easier access from command line) +debug: False + +# pretty print config at the start of the run using Rich library +print_config: True + +# disable python warnings if they annoy you +ignore_warnings: True diff --git a/training/conf/criterion/label_smoothing.yaml b/training/conf/criterion/label_smoothing.yaml index 13daba8..684b5bb 100644 --- a/training/conf/criterion/label_smoothing.yaml +++ b/training/conf/criterion/label_smoothing.yaml @@ -1,4 +1,3 @@ -_target_: text_recognizer.criterion.label_smoothing.LabelSmoothingLoss -label_smoothing: 0.1 -vocab_size: 1006 +_target_: text_recognizer.criterions.label_smoothing.LabelSmoothingLoss +smoothing: 0.1 ignore_index: 1002 diff --git a/training/conf/datamodule/iam_extended_paragraphs.yaml b/training/conf/datamodule/iam_extended_paragraphs.yaml index 3070b56..2d1a03e 100644 --- a/training/conf/datamodule/iam_extended_paragraphs.yaml +++ b/training/conf/datamodule/iam_extended_paragraphs.yaml @@ -1,5 +1,6 @@ _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs -batch_size: 32 +batch_size: 4 num_workers: 12 train_fraction: 0.8 augment: true +pin_memory: false diff --git a/training/conf/lr_scheduler/one_cycle.yaml b/training/conf/lr_scheduler/one_cycle.yaml index 5afdf81..eecee8a 100644 --- a/training/conf/lr_scheduler/one_cycle.yaml +++ b/training/conf/lr_scheduler/one_cycle.yaml @@ -1,8 +1,8 @@ _target_: torch.optim.lr_scheduler.OneCycleLR max_lr: 1.0e-3 total_steps: null -epochs: null -steps_per_epoch: null +epochs: 512 +steps_per_epoch: 4992 pct_start: 0.3 anneal_strategy: cos cycle_momentum: true diff --git a/training/conf/mapping/word_piece.yaml b/training/conf/mapping/word_piece.yaml index 3792523..48384f5 100644 --- a/training/conf/mapping/word_piece.yaml +++ b/training/conf/mapping/word_piece.yaml @@ -1,4 +1,4 @@ -_target_: text_recognizer.data.mappings.WordPieceMapping +_target_: text_recognizer.data.word_piece_mapping.WordPieceMapping num_features: 1000 tokens: iamdb_1kwp_tokens_1000.txt lexicon: iamdb_1kwp_lex_1000.txt @@ -6,4 +6,4 @@ data_dir: null use_words: false prepend_wordsep: false special_tokens: [ <s>, <e>, <p> ] -extra_symbols: [ \n ] +extra_symbols: [ "\n" ] diff --git a/training/conf/model/lit_transformer.yaml b/training/conf/model/lit_transformer.yaml index 6ffde4e..c190151 100644 --- a/training/conf/model/lit_transformer.yaml +++ b/training/conf/model/lit_transformer.yaml @@ -1,7 +1,7 @@ _target_: text_recognizer.models.transformer.TransformerLitModel interval: step monitor: val/loss -ignore_tokens: [ <s>, <e>, <p> ] +max_output_len: 451 start_token: <s> end_token: <e> pad_token: <p> diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml index a97157d..f76e892 100644 --- a/training/conf/network/conv_transformer.yaml +++ b/training/conf/network/conv_transformer.yaml @@ -6,6 +6,5 @@ _target_: text_recognizer.networks.conv_transformer.ConvTransformer input_dims: [1, 576, 640] hidden_dim: 96 dropout_rate: 0.2 -max_output_len: 451 num_classes: 1006 pad_index: 1002 diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml index 90b9d8a..eb80f64 100644 --- a/training/conf/network/decoder/transformer_decoder.yaml +++ b/training/conf/network/decoder/transformer_decoder.yaml @@ -18,3 +18,4 @@ ff_kwargs: dropout_rate: 0.2 cross_attend: true pre_norm: true +rotary_emb: null |