diff options
Diffstat (limited to 'training/run_experiment.py')
-rw-r--r-- | training/run_experiment.py | 11 |
1 files changed, 10 insertions, 1 deletions
diff --git a/training/run_experiment.py b/training/run_experiment.py index 4e045c7..607c3ce 100644 --- a/training/run_experiment.py +++ b/training/run_experiment.py @@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Type import hydra from loguru import logger -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf import pytorch_lightning as pl from torch import nn from tqdm import tqdm @@ -27,6 +27,12 @@ def _create_experiment_dir(config: DictConfig) -> Path: return log_dir +def save_config(config: DictConfig, log_dir: Path) -> None: + """Saves config to log directory.""" + with (log_dir / "config.yaml").open("r") as f: + OmegaConf.save(config=config, f=f) + + def _configure_logging(log_dir: Optional[Path], level: str) -> None: """Configure the loguru logger for output to terminal and disk.""" # Remove default logger to get tqdm to work properly. @@ -125,6 +131,9 @@ def run(config: DictConfig) -> None: # Load ligtning model. lit_model = _load_lit_model(lit_model_class, network, config) + # Save config to experiment dir. + save_config(config, log_dir) + trainer = pl.Trainer( **config.trainer.args, callbacks=callbacks, |