diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/conf/datamodule/iam_extended_paragraphs.yaml | 5 | ||||
-rw-r--r-- | training/conf/dataset/iam_extended_paragraphs.yaml | 6 | ||||
-rw-r--r-- | training/conf/lr_scheduler/one_cycle.yaml | 23 | ||||
-rw-r--r-- | training/conf/model/lit_vqvae.yaml | 2 | ||||
-rw-r--r-- | training/run.py | 1 | ||||
-rw-r--r-- | training/utils.py | 2 |
6 files changed, 23 insertions, 16 deletions
diff --git a/training/conf/datamodule/iam_extended_paragraphs.yaml b/training/conf/datamodule/iam_extended_paragraphs.yaml new file mode 100644 index 0000000..3070b56 --- /dev/null +++ b/training/conf/datamodule/iam_extended_paragraphs.yaml @@ -0,0 +1,5 @@ +_target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs +batch_size: 32 +num_workers: 12 +train_fraction: 0.8 +augment: true diff --git a/training/conf/dataset/iam_extended_paragraphs.yaml b/training/conf/dataset/iam_extended_paragraphs.yaml deleted file mode 100644 index 6439a15..0000000 --- a/training/conf/dataset/iam_extended_paragraphs.yaml +++ /dev/null @@ -1,6 +0,0 @@ -type: IAMExtendedParagraphs -args: - batch_size: 32 - num_workers: 12 - train_fraction: 0.8 - augment: true diff --git a/training/conf/lr_scheduler/one_cycle.yaml b/training/conf/lr_scheduler/one_cycle.yaml index 60a6f27..e8cb5c4 100644 --- a/training/conf/lr_scheduler/one_cycle.yaml +++ b/training/conf/lr_scheduler/one_cycle.yaml @@ -1,8 +1,15 @@ -type: OneCycleLR -args: - interval: step - max_lr: 1.0e-3 - three_phase: true - epochs: 64 - steps_per_epoch: 633 # num_samples / batch_size -monitor: val_loss +_target_: torch.optim.lr_scheduler.OneCycleLR +max_lr: 1.0e-3 +total_steps: None +epochs: None +steps_per_epoch: None +pct_start: 0.3 +anneal_strategy: 'cos' +cycle_momentum: True +base_momentum: 0.85 +max_momentum: 0.95 +div_factor: 25.0 +final_div_factor: 10000.0 +three_phase: true +last_epoch: -1 +verbose: false diff --git a/training/conf/model/lit_vqvae.yaml b/training/conf/model/lit_vqvae.yaml index 7136dbd..6be37e5 100644 --- a/training/conf/model/lit_vqvae.yaml +++ b/training/conf/model/lit_vqvae.yaml @@ -1,3 +1,3 @@ -type: LitVQVAEModel +_target_: text_recognizer.models.vqvae.VQVAELitModel args: mapping: sentence_piece diff --git a/training/run.py b/training/run.py index 5f7c927..31da666 100644 --- a/training/run.py +++ b/training/run.py @@ -51,6 +51,7 @@ def run(config: DictConfig) -> Optional[float]: ) # Log hyperparameters + log.info("Logging hyperparameters") utils.log_hyperparameters(config=config, model=model, trainer=trainer) if config.debug: diff --git a/training/utils.py b/training/utils.py index 4c31dc3..140d97e 100644 --- a/training/utils.py +++ b/training/utils.py @@ -1,4 +1,4 @@ -"""Util functions for training hydra configs and pytorch lightning.""" +"""Util functions for training with hydra and pytorch lightning.""" from typing import Any, List, Type import warnings |