summaryrefslogtreecommitdiff
path: root/src/training/run_sweep.py
blob: a578592bbf3e400c61650b48cb2a36a104022c84 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""W&B Sweep Functionality."""
from ast import literal_eval
import copy
import json
import os
from pathlib import Path
import signal
import subprocess  # nosec
import sys
from typing import Dict, List, Optional, Tuple

import click
import yaml

# Absolute path of the "experiments" directory that sits beside this script.
EXPERIMENTS_DIRNAME = Path(__file__).parent.resolve() / "experiments"


def load_config(filename: str = "default_config_emnist.yml") -> Dict:
    """Load the base hyperparameter config from the experiments directory.

    Args:
        filename (str): Name of the YAML file inside ``EXPERIMENTS_DIRNAME``.
            Defaults to "default_config_emnist.yml".

    Returns:
        Dict: The parsed YAML document, or an empty dict if the file is empty.

    """
    config_path = EXPERIMENTS_DIRNAME / filename
    # Explicit UTF-8 so parsing does not depend on the platform's locale encoding.
    with config_path.open("r", encoding="utf-8") as f:
        default_config = yaml.safe_load(f)
    # yaml.safe_load returns None for an empty document; normalise it so
    # callers always receive a mapping.
    return default_config if default_config is not None else {}


def _coerce_value(raw: str):
    """Convert a command-line string to a Python literal when it is one.

    ``"1.7"`` -> ``1.7`` and ``"[1, 2]"`` -> ``[1, 2]``. Strings that are not
    valid literals (``"adam"``, ``"a b"``, ``""``) are returned unchanged.
    """
    try:
        return literal_eval(raw)
    except (ValueError, SyntaxError, TypeError):
        # ValueError: bare names such as ``adam``. SyntaxError: text that is
        # not a single expression (``a b``, ``a=b``, ``""``). TypeError: e.g.
        # unhashable dict keys. In every case keep the raw string.
        return raw


def _set_nested(config: dict, path: List[str], value) -> None:
    """Set ``config[path[0]]...[path[-1]] = value``, creating missing dicts."""
    nested = config
    for level in path[:-1]:
        # Descend from the *current* level so existing sibling keys survive.
        nested = nested.setdefault(level, {})
    nested[path[-1]] = value


def args_to_json(
    default_config: dict,
    preserve_args: tuple = ("gpu", "save"),
    argv: Optional[List[str]] = None,
) -> Tuple[dict, list]:
    """Convert command line arguments to nested config values.

    Dotted keys address nested dicts, i.e. ``run_sweep.py --dataset_args.foo=1.7``
    yields::

        {
            "dataset_args": {
                "foo": 1.7
            }
        }

    Both ``--key=value`` and ``--key value`` forms are accepted. Values are
    converted with ``ast.literal_eval`` when they are Python literals and kept
    as strings otherwise. ``--key=`` sets an empty string.

    Args:
        default_config (dict): The base config used for every experiment. It is
            not modified; overrides are applied to a deep copy.
        preserve_args (tuple): Top-level keys passed through verbatim as
            ``--key=value`` arguments instead of being merged into the config.
            Defaults to ("gpu", "save").
        argv (Optional[List[str]]): Arguments to parse. Defaults to
            ``sys.argv[1:]``.

    Returns:
        Tuple[dict, list]: The merged config dictionary and the list of
        preserved arguments.

    """
    if argv is None:
        argv = sys.argv[1:]

    args = []
    # Deep copy: overrides are written into nested dicts, and a shallow copy
    # would mutate the caller's default_config.
    config = copy.deepcopy(default_config)
    pending_key = None  # a flag seen without "=", waiting for its value
    for arg in argv:
        if "=" in arg:
            # Split on the first "=" only, so values may themselves contain "=".
            key, val = arg.split("=", 1)
        elif pending_key:
            key, val = pending_key, arg
        else:
            pending_key = arg
            continue
        pending_key = None
        if not key:
            continue

        parsed_key = key.lstrip("-").split(".")
        if parsed_key[0] in preserve_args:
            args.append("--{}={}".format(parsed_key[0], val))
        else:
            _set_nested(config, parsed_key, _coerce_value(val))
    return config, args


def main() -> int:
    """Run one W&B sweep trial by launching ``training/run_experiment.py``.

    The base config is merged with command-line overrides and passed to the
    child as a JSON string. Preserved arguments (e.g. ``--gpu``) are passed
    through unchanged.

    Returns:
        int: The exit status of the experiment subprocess.

    """
    default_config = load_config()
    config, args = args_to_json(default_config)
    # Remove WANDB_PROGRAM / WANDB_ARGS from the child's environment.
    # NOTE(review): presumably so run_experiment.py is not affected by the
    # sweep agent's launch settings; confirm.
    env = {
        k: v for k, v in os.environ.items() if k not in ("WANDB_PROGRAM", "WANDB_ARGS")
    }
    # Use the interpreter running this script, not whatever "python" is on PATH.
    python = sys.executable or "python"
    # start_new_session=True is the documented, fork-safe replacement for
    # preexec_fn=os.setsid: the child calls setsid() and gets its own session.
    run = subprocess.Popen(
        [python, "training/run_experiment.py", *args, json.dumps(config)],
        env=env,
        start_new_session=True,
    )  # nosec
    # The child is in its own session, so forward SIGTERM received here to it.
    signal.signal(signal.SIGTERM, lambda *_: run.terminate())
    return run.wait()


if __name__ == "__main__":
    # Use main()'s return value as the process exit status (None exits with 0).
    sys.exit(main())