summaryrefslogtreecommitdiff
path: root/training/run_experiment.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/run_experiment.py')
-rw-r--r--training/run_experiment.py11
1 files changed, 10 insertions, 1 deletions
diff --git a/training/run_experiment.py b/training/run_experiment.py
index 4e045c7..607c3ce 100644
--- a/training/run_experiment.py
+++ b/training/run_experiment.py
@@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Type
import hydra
from loguru import logger
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf
import pytorch_lightning as pl
from torch import nn
from tqdm import tqdm
@@ -27,6 +27,12 @@ def _create_experiment_dir(config: DictConfig) -> Path:
return log_dir
+def save_config(config: DictConfig, log_dir: Path) -> None:
+ """Saves config to log directory."""
+ with (log_dir / "config.yaml").open("r") as f:
+ OmegaConf.save(config=config, f=f)
+
+
def _configure_logging(log_dir: Optional[Path], level: str) -> None:
"""Configure the loguru logger for output to terminal and disk."""
# Remove default logger to get tqdm to work properly.
@@ -125,6 +131,9 @@ def run(config: DictConfig) -> None:
# Load ligtning model.
lit_model = _load_lit_model(lit_model_class, network, config)
+ # Save config to experiment dir.
+ save_config(config, log_dir)
+
trainer = pl.Trainer(
**config.trainer.args,
callbacks=callbacks,