summaryrefslogtreecommitdiff
path: root/training/conf
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-03 18:18:48 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-03 18:18:48 +0200
commitbd4bd443f339e95007bfdabf3e060db720f4d4b9 (patch)
treee55cb3744904f7c2a0348b100c7e92a65e538a16 /training/conf
parent75801019981492eedf9280cb352eea3d8e99b65f (diff)
Training working, multiple bug fixes
Diffstat (limited to 'training/conf')
-rw-r--r--training/conf/callbacks/checkpoint.yaml2
-rw-r--r--training/conf/callbacks/wandb_checkpoints.yaml (renamed from training/conf/callbacks/wandb/checkpoints.yaml)0
-rw-r--r--training/conf/callbacks/wandb_code.yaml (renamed from training/conf/callbacks/wandb/code.yaml)0
-rw-r--r--training/conf/callbacks/wandb_image_reconstructions.yaml (renamed from training/conf/callbacks/wandb/image_reconstructions.yaml)0
-rw-r--r--training/conf/callbacks/wandb_ocr.yaml8
-rw-r--r--training/conf/callbacks/wandb_ocr_predictions.yaml (renamed from training/conf/callbacks/wandb/ocr_predictions.yaml)0
-rw-r--r--training/conf/callbacks/wandb_watch.yaml (renamed from training/conf/callbacks/wandb/watch.yaml)0
-rw-r--r--training/conf/config.yaml21
-rw-r--r--training/conf/criterion/label_smoothing.yaml5
-rw-r--r--training/conf/datamodule/iam_extended_paragraphs.yaml3
-rw-r--r--training/conf/lr_scheduler/one_cycle.yaml4
-rw-r--r--training/conf/mapping/word_piece.yaml4
-rw-r--r--training/conf/model/lit_transformer.yaml2
-rw-r--r--training/conf/network/conv_transformer.yaml1
-rw-r--r--training/conf/network/decoder/transformer_decoder.yaml1
15 files changed, 35 insertions, 16 deletions
diff --git a/training/conf/callbacks/checkpoint.yaml b/training/conf/callbacks/checkpoint.yaml
index db34cb1..b4101d8 100644
--- a/training/conf/callbacks/checkpoint.yaml
+++ b/training/conf/callbacks/checkpoint.yaml
@@ -6,4 +6,4 @@ model_checkpoint:
mode: min # can be "max" or "min"
verbose: false
dirpath: checkpoints/
- filename: {epoch:02d}
+ filename: "{epoch:02d}"
diff --git a/training/conf/callbacks/wandb/checkpoints.yaml b/training/conf/callbacks/wandb_checkpoints.yaml
index a4a16ff..a4a16ff 100644
--- a/training/conf/callbacks/wandb/checkpoints.yaml
+++ b/training/conf/callbacks/wandb_checkpoints.yaml
diff --git a/training/conf/callbacks/wandb/code.yaml b/training/conf/callbacks/wandb_code.yaml
index 35f6ea3..35f6ea3 100644
--- a/training/conf/callbacks/wandb/code.yaml
+++ b/training/conf/callbacks/wandb_code.yaml
diff --git a/training/conf/callbacks/wandb/image_reconstructions.yaml b/training/conf/callbacks/wandb_image_reconstructions.yaml
index e69de29..e69de29 100644
--- a/training/conf/callbacks/wandb/image_reconstructions.yaml
+++ b/training/conf/callbacks/wandb_image_reconstructions.yaml
diff --git a/training/conf/callbacks/wandb_ocr.yaml b/training/conf/callbacks/wandb_ocr.yaml
index efa3dda..9c9a6da 100644
--- a/training/conf/callbacks/wandb_ocr.yaml
+++ b/training/conf/callbacks/wandb_ocr.yaml
@@ -1,6 +1,6 @@
defaults:
- default
- - wandb/watch
- - wandb/code
- - wandb/checkpoints
- - wandb/ocr_predictions
+ - wandb_watch
+ - wandb_code
+ - wandb_checkpoints
+ - wandb_ocr_predictions
diff --git a/training/conf/callbacks/wandb/ocr_predictions.yaml b/training/conf/callbacks/wandb_ocr_predictions.yaml
index 573fa96..573fa96 100644
--- a/training/conf/callbacks/wandb/ocr_predictions.yaml
+++ b/training/conf/callbacks/wandb_ocr_predictions.yaml
diff --git a/training/conf/callbacks/wandb/watch.yaml b/training/conf/callbacks/wandb_watch.yaml
index 511608c..511608c 100644
--- a/training/conf/callbacks/wandb/watch.yaml
+++ b/training/conf/callbacks/wandb_watch.yaml
diff --git a/training/conf/config.yaml b/training/conf/config.yaml
index 93215ed..782bcbb 100644
--- a/training/conf/config.yaml
+++ b/training/conf/config.yaml
@@ -1,8 +1,9 @@
defaults:
- callbacks: wandb_ocr
- criterion: label_smoothing
- - dataset: iam_extended_paragraphs
+ - datamodule: iam_extended_paragraphs
- hydra: default
+ - logger: wandb
- lr_scheduler: one_cycle
- mapping: word_piece
- model: lit_transformer
@@ -15,3 +16,21 @@ tune: false
train: true
test: true
logging: INFO
+
+# path to original working directory
+# hydra hijacks working directory by changing it to the current log directory,
+# so it's useful to have this path as a special variable
+# learn more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
+work_dir: ${hydra:runtime.cwd}
+
+# use `python run.py debug=true` for easy debugging!
+# this will run 1 train, val and test loop with only 1 batch
+# equivalent to running `python run.py trainer.fast_dev_run=true`
+# (this is placed here just for easier access from command line)
+debug: False
+
+# pretty print config at the start of the run using Rich library
+print_config: True
+
+# disable python warnings if they annoy you
+ignore_warnings: True
diff --git a/training/conf/criterion/label_smoothing.yaml b/training/conf/criterion/label_smoothing.yaml
index 13daba8..684b5bb 100644
--- a/training/conf/criterion/label_smoothing.yaml
+++ b/training/conf/criterion/label_smoothing.yaml
@@ -1,4 +1,3 @@
-_target_: text_recognizer.criterion.label_smoothing.LabelSmoothingLoss
-label_smoothing: 0.1
-vocab_size: 1006
+_target_: text_recognizer.criterions.label_smoothing.LabelSmoothingLoss
+smoothing: 0.1
ignore_index: 1002
diff --git a/training/conf/datamodule/iam_extended_paragraphs.yaml b/training/conf/datamodule/iam_extended_paragraphs.yaml
index 3070b56..2d1a03e 100644
--- a/training/conf/datamodule/iam_extended_paragraphs.yaml
+++ b/training/conf/datamodule/iam_extended_paragraphs.yaml
@@ -1,5 +1,6 @@
_target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs
-batch_size: 32
+batch_size: 4
num_workers: 12
train_fraction: 0.8
augment: true
+pin_memory: false
diff --git a/training/conf/lr_scheduler/one_cycle.yaml b/training/conf/lr_scheduler/one_cycle.yaml
index 5afdf81..eecee8a 100644
--- a/training/conf/lr_scheduler/one_cycle.yaml
+++ b/training/conf/lr_scheduler/one_cycle.yaml
@@ -1,8 +1,8 @@
_target_: torch.optim.lr_scheduler.OneCycleLR
max_lr: 1.0e-3
total_steps: null
-epochs: null
-steps_per_epoch: null
+epochs: 512
+steps_per_epoch: 4992
pct_start: 0.3
anneal_strategy: cos
cycle_momentum: true
diff --git a/training/conf/mapping/word_piece.yaml b/training/conf/mapping/word_piece.yaml
index 3792523..48384f5 100644
--- a/training/conf/mapping/word_piece.yaml
+++ b/training/conf/mapping/word_piece.yaml
@@ -1,4 +1,4 @@
-_target_: text_recognizer.data.mappings.WordPieceMapping
+_target_: text_recognizer.data.word_piece_mapping.WordPieceMapping
num_features: 1000
tokens: iamdb_1kwp_tokens_1000.txt
lexicon: iamdb_1kwp_lex_1000.txt
@@ -6,4 +6,4 @@ data_dir: null
use_words: false
prepend_wordsep: false
special_tokens: [ <s>, <e>, <p> ]
-extra_symbols: [ \n ]
+extra_symbols: [ "\n" ]
diff --git a/training/conf/model/lit_transformer.yaml b/training/conf/model/lit_transformer.yaml
index 6ffde4e..c190151 100644
--- a/training/conf/model/lit_transformer.yaml
+++ b/training/conf/model/lit_transformer.yaml
@@ -1,7 +1,7 @@
_target_: text_recognizer.models.transformer.TransformerLitModel
interval: step
monitor: val/loss
-ignore_tokens: [ <s>, <e>, <p> ]
+max_output_len: 451
start_token: <s>
end_token: <e>
pad_token: <p>
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml
index a97157d..f76e892 100644
--- a/training/conf/network/conv_transformer.yaml
+++ b/training/conf/network/conv_transformer.yaml
@@ -6,6 +6,5 @@ _target_: text_recognizer.networks.conv_transformer.ConvTransformer
input_dims: [1, 576, 640]
hidden_dim: 96
dropout_rate: 0.2
-max_output_len: 451
num_classes: 1006
pad_index: 1002
diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml
index 90b9d8a..eb80f64 100644
--- a/training/conf/network/decoder/transformer_decoder.yaml
+++ b/training/conf/network/decoder/transformer_decoder.yaml
@@ -18,3 +18,4 @@ ff_kwargs:
dropout_rate: 0.2
cross_attend: true
pre_norm: true
+rotary_emb: null