diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-06-27 17:04:42 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-06-27 17:04:42 +0200 |
commit | cafd6b8b10d804b3eee235652cb5218ef4a469b4 (patch) | |
tree | d292a7294e7de3c4cafe5ccee4ece8b56ce66b5e | |
parent | 0ba945e84d11a07ac95fdf8495f2ff278215adb9 (diff) |
Save experiment config to log dirlightning-refactor
-rw-r--r-- | training/run_experiment.py | 11 |
1 files changed, 10 insertions, 1 deletions
diff --git a/training/run_experiment.py b/training/run_experiment.py index 4e045c7..607c3ce 100644 --- a/training/run_experiment.py +++ b/training/run_experiment.py @@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Type import hydra from loguru import logger -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf import pytorch_lightning as pl from torch import nn from tqdm import tqdm @@ -27,6 +27,12 @@ def _create_experiment_dir(config: DictConfig) -> Path: return log_dir +def save_config(config: DictConfig, log_dir: Path) -> None: + """Saves config to log directory.""" + with (log_dir / "config.yaml").open("r") as f: + OmegaConf.save(config=config, f=f) + + def _configure_logging(log_dir: Optional[Path], level: str) -> None: """Configure the loguru logger for output to terminal and disk.""" # Remove default logger to get tqdm to work properly. @@ -125,6 +131,9 @@ def run(config: DictConfig) -> None: # Load ligtning model. lit_model = _load_lit_model(lit_model_class, network, config) + # Save config to experiment dir. + save_config(config, log_dir) + trainer = pl.Trainer( **config.trainer.args, callbacks=callbacks, |